Source code for towhee.operator.base

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from abc import ABC
from enum import Enum
from enum import Flag, auto


class SharedType(Enum):
    """
    Sharing category reported by an operator through its `shared_type` property.
    """

    # Stateful & reusable operator.
    NotShareable = 1

    # Stateful & not reusable operator.
    NotReusable = 2

    # Stateless operator.
    Shareable = 3


class OperatorFlag(Flag):
    """
    Bit flags describing an operator; members can be combined with `|`.

    Every member is assigned with `auto()`, so each one occupies its own bit
    (1, 2, 4). In particular `EMPTYFLAG` is a regular non-zero member, not the
    zero value of the flag type.
    """

    EMPTYFLAG = auto()
    STATELESS = auto()
    REUSEABLE = auto()
class Operator(ABC):
    """
    Operator base class, implements __init__ and __call__,

    Both `__init__` and `__call__` are abstract, so a concrete operator has to
    override each of them. Subclasses should still call `super().__init__()`,
    because that is where `_device_id` and `_key` are initialised.

    Examples:
        class AddOperator(Operator):
            def __init__(self, factor: int):
                super().__init__()
                self._factor = factor

            def __call__(self, num) -> NamedTuple("Outputs", [("sum", int)]):
                Outputs = NamedTuple("Outputs", [("sum", int)])
                return Outputs(self._factor + num)
    """

    @abstractmethod
    def __init__(self):
        """
        Init operator, before a graph starts, the framework will call Operator __init__ function.

        Sets `self._device_id` from the system config (or -1 when no system
        config is available) and resets `self._key` to an empty string.

        Raises:
            An exception during __init__ can terminate the graph run.
        """
        # Imported lazily, as in the original code (pylint's
        # import-outside-toplevel is silenced on purpose).
        from towhee.runtime import get_sys_config  # pylint: disable=import-outside-toplevel

        sys_config = get_sys_config()
        if sys_config is not None:
            self._device_id = sys_config.device_id
        else:
            # No system config: fall back to -1.
            # NOTE(review): presumably "no specific device" -- confirm with runtime.
            self._device_id = -1
        self._key = ''

    @abstractmethod
    def __call__(self):
        """
        The framework calls __call__ function repeatedly for every input data.

        Raises:
            NotImplementedError: Always, in this base implementation.
            An exception during __call__ can terminate the graph run.
        """
        raise NotImplementedError

    @property
    def key(self):
        """The key stored on this operator (an empty string after `Operator.__init__`)."""
        return self._key

    @key.setter
    def key(self, value):
        self._key = value

    @property
    def shared_type(self):
        """Sharing category of the operator; the base class reports `SharedType.NotShareable`."""
        return SharedType.NotShareable

    @property
    def flag(self):
        """Operator flags; the base class reports `STATELESS | REUSEABLE`."""
        return OperatorFlag.STATELESS | OperatorFlag.REUSEABLE
class NNOperator(Operator):
    """
    Neural Network related operators that involve machine learning frameworks.

    Args:
        framework (`str`):
            The framework to apply.
    """

    def __init__(self, framework: str = 'pytorch'):
        super().__init__()
        self._framework = framework
        self.operator_name = type(self).__name__
        # `model` and `model_card` start empty and are left for subclasses or
        # callers to assign; `model` is what gets handed to the Trainer.
        self.model = None
        self.model_card = None
        # Created lazily by `setup_trainer` / the `trainer` property.
        self._trainer = None

    @property
    def flag(self):
        """Operator flags; reports `STATELESS | REUSEABLE`."""
        return OperatorFlag.STATELESS | OperatorFlag.REUSEABLE

    @property
    def framework(self):
        """Name of the machine learning framework this operator uses."""
        return self._framework

    @framework.setter
    def framework(self, framework: str):
        self._framework = framework

    def save_model(self):
        """
        Save model to local.

        Raises:
            NotImplementedError: Always; subclasses are expected to override this.
        """
        raise NotImplementedError

    def supported_model_names(self):
        """
        Return a list of supported model names.

        Raises:
            NotImplementedError: Always; subclasses are expected to override this.
        """
        raise NotImplementedError

    def train(self,
              training_config=None,
              train_dataset=None,
              eval_dataset=None,
              resume_checkpoint_path=None,
              **kwargs):
        """
        Start to train an operator.

        Args:
            training_config (`TrainingConfig`):
                The config of this trainer.
            train_dataset (`Union[Dataset, TowheeDataSet]`):
                Training dataset.
            eval_dataset (`Union[Dataset, TowheeDataSet]`):
                Evaluate dataset.
            resume_checkpoint_path (`str`):
                If resuming training, pass into the path.
            **kwargs (`Any`):
                Keyword Args. They are forwarded to `setup_trainer`, so only
                the keywords it accepts (`train_dataloader`,
                `eval_dataloader`, `model_card`) are valid.
        """
        self.setup_trainer(training_config, train_dataset, eval_dataset, **kwargs)
        self.trainer.train(resume_checkpoint_path)

    @property
    def trainer(self) -> 'Trainer':
        """The trainer instance; built with default arguments on first access if none exists."""
        if self._trainer is None:
            self.setup_trainer()
        return self._trainer

    @trainer.setter
    def trainer(self, trainer):
        self._trainer = trainer

    def setup_trainer(self,
                      training_config=None,
                      train_dataset=None,
                      eval_dataset=None,
                      train_dataloader=None,
                      eval_dataloader=None,
                      model_card=None):
        """
        Set up the trainer instance in operator before training and set trainer parameters.

        When no trainer exists yet, a new `Trainer` is created from
        `self.model` and the given arguments. Otherwise the existing trainer
        is updated in place, and only for the arguments that are not None.

        Args:
            training_config (`TrainingConfig`):
                The config of this trainer.
            train_dataset (`Union[Dataset, TowheeDataSet]`):
                Training dataset.
            eval_dataset (`Union[Dataset, TowheeDataSet]`):
                Evaluate dataset.
            train_dataloader (`Union[DataLoader, Iterable]`):
                If specified, `Trainer` will use it to load training data.
                Otherwise, `Trainer` will build dataloader from train_dataset.
            eval_dataloader (`Union[DataLoader, Iterable]`):
                If specified, `Trainer` will use it to load evaluate data.
                Otherwise, `Trainer` will build dataloader from eval_dataset.
            model_card (`ModelCard`):
                Model card contains the training information.
        """
        # Imported lazily, as in the original code (pylint's
        # import-outside-toplevel is silenced on purpose).
        from towhee.trainer.trainer import Trainer  # pylint: disable=import-outside-toplevel

        if self._trainer is None:
            self._trainer = Trainer(self.model,
                                    training_config,
                                    train_dataset,
                                    eval_dataset,
                                    train_dataloader,
                                    eval_dataloader,
                                    model_card)
        else:
            # A trainer already exists: overwrite only what the caller supplied,
            # so earlier settings survive a partial update.
            if training_config is not None:
                self._trainer.configs = training_config
            if train_dataset is not None:
                self._trainer.train_dataset = train_dataset
            if eval_dataset is not None:
                self._trainer.eval_dataset = eval_dataset
            if train_dataloader is not None:
                self._trainer.train_dataloader = train_dataloader
            if eval_dataloader is not None:
                self._trainer.eval_dataloader = eval_dataloader
            if model_card is not None:
                self._trainer.model_card = model_card

    def load(self, path: str = None):
        """
        Load the model checkpoint into an operator.

        Delegates to `self.trainer.load`, so a trainer is created on demand.

        Args:
            path (`str`):
                The folder path containing the model's checkpoints.
        """
        self.trainer.load(path)

    def save(self, path: str, overwrite: bool = True):
        """
        Save the model checkpoint into the path.

        Delegates to `self.trainer.save`, so a trainer is created on demand.

        Args:
            path (`str`):
                The folder path containing the model's checkpoints.
            overwrite (`bool`):
                If True, it will overwrite the same name path when existing.

        Raises:
            (`FileExistsError`)
                If `overwrite` is False, when there already exists a path, it will raise Error.
        """
        self.trainer.save(path, overwrite)
class PyOperator(Operator):
    """
    Python function operator, no machine learning frameworks involved.
    """

    def __init__(self):
        # The override is required even though it only delegates:
        # `Operator.__init__` is abstract, so it must be redefined here.
        super().__init__()

    @property
    def flag(self):
        """Operator flags; reports `OperatorFlag.EMPTYFLAG`."""
        return OperatorFlag.EMPTYFLAG