Source code for towhee.dag.operator_repr

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Any

from towhee.dag.base_repr import BaseRepr


class OperatorRepr(BaseRepr):
    """
    This class encapsulates operator representations at compile-time.

    Args:
        name (`str`):
            Name of the operator represented by this object.
        function (`str`):
            The path that leads to the operator.
        init_args (`Dict[str, Any]`):
            The args to initialize the operator.
        inputs (`List[Dict[str, Any]]`):
            Input dataframe(s) to this object. `from_ir` passes `None` here.
        outputs (`List[Dict[str, Any]]`):
            This operator's output dataframe(s). `from_ir` passes `None` here.
        iter_info (`Dict[str, Any]`):
            This operator's iterator info. Its `'type'` entry is consulted when
            `threads` is not 1. `from_ir` passes `None` here.
        tag (`str`):
            The tag of the operator to load, defaults to `'main'`.
        threads (`int`):
            The number of threads for the operator, defaults to 1. Any other
            value is only accepted when the iterator type is `'filter'` or
            `'map'`.

    Raises:
        (`RuntimeError`)
            If `threads` is not 1 and the iterator type is neither `'filter'`
            nor `'map'` (including the case where `iter_info` is `None`).
    """

    def __init__(
        self,
        name: str,
        function: str,
        init_args: Dict[str, Any],
        inputs: List[Dict[str, Any]],
        outputs: List[Dict[str, Any]],
        iter_info: Dict[str, Any],
        tag: str = 'main',
        threads: int = 1
    ):
        super().__init__(name)
        self._function = function
        self._tag = tag
        self._inputs = inputs
        self._outputs = outputs
        self._init_args = init_args
        self._iter_info = iter_info
        self._threads = threads

        if self._threads != 1:
            # `from_ir` builds instances with `iter_info=None`; treat a missing
            # iterator description as "no type" so the caller gets the intended
            # RuntimeError instead of an AttributeError on `None.get`.
            iter_type = self._iter_info.get('type') if self._iter_info is not None else None
            if iter_type not in ['filter', 'map']:
                raise RuntimeError('Only filter and map operator support multi-threads')

    @property
    def function(self) -> str:
        """
        Returns:
            (`str`)
                The path that leads to the operator.
        """
        return self._function

    @property
    def inputs(self) -> List[Dict[str, Any]]:
        """
        Returns:
            (`List[Dict[str, Any]]`)
                The inputs of the operator.
        """
        return self._inputs

    @property
    def outputs(self) -> List[Dict[str, Any]]:
        """
        Returns:
            (`List[Dict[str, Any]]`)
                The outputs of the operator.
        """
        return self._outputs

    @property
    def init_args(self) -> Dict[str, Any]:
        """
        Returns:
            (`Dict[str, Any]`)
                The args to initialize the operator.
        """
        return self._init_args

    @property
    def tag(self) -> str:
        """
        Returns:
            (`str`)
                The tag of the operator to load.
        """
        return self._tag

    @property
    def iter_info(self) -> Dict[str, Any]:
        """
        Returns:
            (`Dict[str, Any]`)
                The operator's iterator info.
        """
        return self._iter_info

    @property
    def threads(self) -> int:
        """
        Returns:
            (`int`)
                The number of threads configured for the operator.
        """
        return self._threads

    @staticmethod
    def from_dict(info: Dict[str, Any]) -> 'OperatorRepr':
        """
        Generate an OperatorRepr from a description dict.

        The keys `name`, `init_args`, `function`, `inputs`, `outputs` and
        `iter_info` are validated through `BaseRepr.is_valid`; `tag` and
        `threads` are optional and default to `'main'` and 1 respectively.
        The passed dict is only read, never modified.

        Args:
            info (`Dict[str, Any]`):
                A dict to describe the Operator.

        Returns:
            (`towhee.dag.OperatorRepr`)
                The OperatorRepr object.

        Raises:
            (`ValueError`)
                If `BaseRepr.is_valid` rejects `info`.
        """
        required_keys = {'name', 'init_args', 'function', 'inputs', 'outputs', 'iter_info'}
        if not BaseRepr.is_valid(info, required_keys):
            raise ValueError('Invalid operator info.')

        # Read the optional entries with defaults rather than writing the
        # default tag back into the caller's dict.
        return OperatorRepr(
            info['name'],
            info['function'],
            info['init_args'],
            info['inputs'],
            info['outputs'],
            info['iter_info'],
            info.get('tag', 'main'),
            info.get('threads', 1)
        )

    @staticmethod
    def from_ir(function: str, init_args: Dict[str, Any]) -> 'OperatorRepr':
        """
        Generate an OperatorRepr from a function path and its init args only.

        The resulting object has an empty name, `None` for `inputs`, `outputs`
        and `iter_info`, and the default `tag` and `threads`.

        Args:
            function (`str`):
                The path that leads to the operator.
            init_args (`Dict[str, Any]`):
                The args to initialize the operator.

        Returns:
            (`towhee.dag.OperatorRepr`)
                The OperatorRepr object.
        """
        return OperatorRepr('', function, init_args, None, None, None)