# Source code for towhee.operator.stateful_operator

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from towhee.hparam import param_scope
from towhee.operator import Operator


class StatefulOperator(Operator):
    """
    Stateful operator.

    Base class for operators that learn state from data before they are used
    for inference. An instance is in one of two modes:

    - training (after ``set_training(True)``): every call buffers its
      positional arguments column-wise in ``self._data`` (see :meth:`feed`)
      and returns ``None``;
    - inference (the initial mode): every call is forwarded to
      :meth:`predict` and its result is returned.

    :meth:`set_state` binds ``self._state`` to the entry of a shared state
    object that is named after this operator. Subclasses override :meth:`fit`
    to store learned values on ``self._state`` and :meth:`predict` to use
    them, as in the example below.

    Args:
        name (str, optional):
            Key of this operator's entry in the shared state. When omitted or
            falsy, a name is derived by :meth:`_get_default_name`.

    Examples:

    >>> from towhee import register
    >>> from towhee import DataCollection, State
    >>> from towhee.functional.entity import Entity
    >>> import numpy as np
    >>> @register
    ... class my_normalize(StatefulOperator):
    ...     def __init__(self, name):
    ...         super().__init__(name=name)
    ...     def fit(self):
    ...         self._state._mu = np.mean(self._data[0])
    ...         self._state._std = np.std(self._data[0])
    ...     def predict(self, x):
    ...         return (x-self._state._mu)/self._state._std

    >>> dc = (
    ...     DataCollection.range(10)
    ...         .set_training(State())
    ...         .map(lambda x: Entity(a=x))
    ...         .my_normalize['a', 'b'](name='mynorm')
    ... )

    >>> [int(x.b*10) for x in dc.to_list()]
    [-15, -12, -8, -5, -1, 1, 5, 8, 12, 15]

    >>> dc._state.mynorm._mu
    4.5
    """

    def __init__(self, name=None):
        super().__init__()
        # An empty/None name falls back to one derived from the class name
        # and the input/output fields of the current param_scope.
        self._name = name if name else self._get_default_name()
        # Column-wise buffer of training inputs; created lazily by feed().
        self._data = None
        # Instances start in inference mode.
        self._training = False

    def _get_default_name(self):
        """
        Build a default name of the form ``<ClassName>_<inputs>_<outputs>``.

        Underscores inside each part are replaced with ``-`` so that ``_``
        only separates the three parts. If a part lists several fields, they
        are joined with ``-``. The input and output fields come from
        ``hp.index[0]`` and ``hp.index[1]`` of the active ``param_scope``.
        NOTE(review): presumably ``hp.index`` is populated by the
        ``op[inputs, outputs]`` syntax shown in the class example; confirm.

        Returns:
            str: the derived name.
        """
        def _slug(fields):
            # One field arrives as a plain str; several arrive as an iterable
            # of str that is joined with '-'.
            joined = fields if isinstance(fields, str) else '-'.join(fields)
            return joined.replace('_', '-')

        with param_scope() as hp:
            name = self.__class__.__name__.replace('_', '-')
            inputs = _slug(hp.index[0])
            outputs = _slug(hp.index[1])
            return '_'.join([name, inputs, outputs])

    def set_state(self, state):
        """
        Bind ``self._state`` to the attribute of ``state`` named ``self._name``.

        If that attribute is falsy, ``.model = 1`` is assigned on it first.
        NOTE(review): this looks like a way to materialise an empty
        auto-created entry on the state container; confirm against ``State``.

        Args:
            state: shared state object, accessed by attribute name.
        """
        if not getattr(state, self._name):
            getattr(state, self._name).model = 1
        # Look the attribute up again rather than reusing the first lookup,
        # which may have returned a transient object before the assignment.
        self._state = getattr(state, self._name)

    def set_training(self, flag):
        """
        Switch between training (``flag`` truthy) and inference mode.

        Args:
            flag (bool): whether subsequent calls should buffer data via
                :meth:`feed` instead of calling :meth:`predict`.
        """
        self._training = flag

    def feed(self, *arg):
        """
        Buffer one sample column-wise in ``self._data``.

        The first call creates one list per positional argument. Later calls
        append each argument to the list at the same position. Passing more
        arguments than the first call did raises ``IndexError``.

        Args:
            *arg: the sample's values, one per input column.

        Returns:
            None
        """
        if self._data is None:
            self._data = [[x] for x in arg]
            return
        for column, value in enumerate(arg):
            self._data[column].append(value)

    def fit(self, *arg):
        """
        Default training hook; subclasses are expected to override it.

        When called without arguments and some data has been buffered,
        re-invokes ``self.fit`` with the buffered columns as positional
        arguments. This dispatches to a subclass override that accepts them.
        In every other case it does nothing.

        Args:
            *arg: training columns; when empty, ``self._data`` is used.

        Returns:
            Whatever the re-invoked ``self.fit`` returns, otherwise None.
        """
        if not arg and self._data:
            return self.fit(*self._data)

    def predict(self, *arg):
        """
        Inference hook; subclasses override it. The base version returns None.

        Args:
            *arg: the operator's inputs.
        """
        pass

    def __call__(self, *arg):
        """
        Buffer ``arg`` via :meth:`feed` in training mode, else return
        :meth:`predict` applied to ``arg``.
        """
        if self._training:
            return self.feed(*arg)
        else:
            return self.predict(*arg)
if __name__ == '__main__':
    # Only when this file is executed directly: verify that the examples in
    # this module's docstrings still produce their recorded output.
    from doctest import testmod
    testmod(verbose=False)