# Source code for towhee.functional.mixins.state

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from towhee.hparam import param_scope
from towhee.hparam import HyperParameter


class StateMixin:
    """
    Mixin for state tracking.

    Examples:

    >>> from towhee import DataCollection, State
    >>> from towhee import param_scope
    >>> dc = DataCollection.range(10).set_state(State(a=1))
    >>> dc.get_state()
    {'a': 1}

    >>> dc = dc.map(lambda x: x+1).map(lambda x: x*2)
    >>> dc.get_state()
    {'a': 1}
    """

    def __init__(self):
        super().__init__()
        # If a parent collection is registered in the active param_scope under
        # `data_collection.parent`, reuse its state storage so that collections
        # derived from it (e.g. via `map`) share the same state object.
        with param_scope() as hp:
            parent = hp().data_collection.parent(None)
        if parent is not None:
            self.set_state(parent.get_state())

    def get_state(self):
        """
        Get the state storage for `DataCollection`

        Returns:
            State: the state storage, or None when no `HyperParameter`
            instance has been attached to this collection.
        """
        if hasattr(self, '_state') and isinstance(self._state, HyperParameter):
            return self._state
        return None

    def set_state(self, state):
        """
        Set the state storage for `DataCollection`

        Args:
            state (State): state storage

        Returns:
            DataCollection: data collection itself
        """
        self._state = state
        return self

    def _set_mode(self, mode, state=None):
        """
        Attach `state` (if given), make sure a storage exists, and record `mode` in it.

        Args:
            mode (str): mode name stored as `__mode__` in the state storage.
            state (State, optional): replacement state storage. Defaults to None.

        Returns:
            DataCollection: data collection itself
        """
        if state is not None:
            self._state = state
        # A missing or non-HyperParameter state is replaced by a fresh storage.
        if self.get_state() is None:
            self._state = HyperParameter()
        self._state.__mode__ = mode
        return self

    def set_training(self, state=None):
        """
        Set training mode for stateful operators

        Args:
            state (State, optional): Update the state storage. Defaults to None.

        Returns:
            DataCollection: data collection itself
        """
        return self._set_mode('training', state)

    def set_evaluating(self, state=None):
        """
        Set evaluating mode for stateful operators

        Args:
            state (State, optional): Update the state storage. Defaults to None.

        Returns:
            DataCollection: data collection itself
        """
        return self._set_mode('evaluating', state)

    def smap(self, op):
        """
        Stateful map: bind `op` to this collection's state and apply it.

        When the state storage's `__mode__` is 'training', the operator is
        switched into training mode, called once on every element of this
        collection, switched back out of training mode and then `fit()`.
        In either mode the operator is afterwards applied to the elements:
        through `op.__dataframe_apply__` when the underlying iterable has an
        `apply` attribute and the operator defines that hook, otherwise via
        `map`.

        Args:
            op: a stateful operator; it must provide `set_state`,
                `set_training`, `fit` and be callable.

        Returns:
            The result of `self._factory` on the mapped elements (presumably a
            new DataCollection — depends on the host class).
        """
        state = self.get_state()
        if state is None:
            # Same lazy initialisation as set_training/set_evaluating. A fresh
            # storage has no `__mode__`, so it is treated as 'evaluating' below.
            state = HyperParameter()
            self._state = state
        op.set_state(state)
        if state().__mode__('evaluating') == 'training':
            op.set_training(True)
            try:
                with param_scope() as hp:
                    hp().towhee.data_collection.training = True
                    # NOTE(review): this pass iterates `self` before the map
                    # below; if the underlying iterable is a one-shot iterator
                    # the map may see no elements — confirm for stream mode.
                    for x in self:
                        op(x)
            finally:
                # Never leave the operator stuck in training mode, even if a
                # training step raised.
                op.set_training(False)
            op.fit()
        if hasattr(self._iterable, 'apply') and hasattr(op, '__dataframe_apply__'):
            return self._factory(op.__dataframe_apply__(self._iterable))
        return self._factory(map(op, self._iterable))