Source code for towhee.engine.operator_loader

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.util
import subprocess
import sys
from pathlib import Path
from typing import Any, List, Dict, Union

import pkg_resources
from pkg_resources import DistributionNotFound
from towhee.operator import Operator
from towhee.operator.nop import NOPOperator
from towhee.operator.concat_operator import ConcatOperator
from towhee.engine import LOCAL_OPERATOR_CACHE
from towhee.hub.file_manager import FileManager
from towhee.hparam import param_scope

from .operator_registry import OperatorRegistry

class OperatorLoader:
    """Wrapper class used to load operators from either local cache or a remote location.

    Operators are resolved by trying, in order: the built-in internal operators,
    the `OperatorRegistry`, installed `towheeoperator` packages, and finally the
    operator files returned by `FileManager`.

    Args:
        cache_path: (`str`)
            Local cache path to use. If not specified, it will default to
            `$HOME/.towhee/operators`.
    """

    def __init__(self, cache_path: str = None):
        if cache_path is None:
            self._cache_path = LOCAL_OPERATOR_CACHE
        else:
            self._cache_path = Path(cache_path)

    @staticmethod
    def _to_class_name(fname: str) -> str:
        """Convert a snake_case operator name into its CamelCase class name.

        Empty segments (produced by leading, trailing or doubled underscores) are
        kept as a literal `_`, e.g. `my_op` -> `MyOp` and `_my_op` -> `_MyOp`.
        """
        return ''.join(part.capitalize() or '_' for part in fname.split('_'))

    @staticmethod
    def _install_missing_requirements(req_file: Path) -> None:
        """Pip-install every requirement in `req_file` that is not installed yet.

        Does nothing when `req_file` does not exist. Only a missing distribution
        triggers an install; any other resolution error raised by
        `pkg_resources.require` propagates to the caller.
        """
        if not req_file.is_file():
            return

        with open(req_file, 'r', encoding='utf-8') as f:
            raw_lines = f.read().splitlines()

        for raw in raw_lines:
            # Drop trailing inline comments so that they are never handed to pip
            # as part of the requirement string.
            for marker in (' #', '\t#'):
                raw = raw.split(marker, 1)[0]
            req = raw.strip()
            if not req or req.startswith('#'):
                continue
            try:
                pkg_resources.require(req)
            except DistributionNotFound:
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', req])

    def _load_interal_op(self, op_name: str, arg: List[Any], kws: Dict[str, Any]):
        """Build one of the built-in operators, or return `None` for any other name.

        `_start_op` and `_end_op` map to `NOPOperator`; `_concat` maps to
        `ConcatOperator` constructed from `arg` and `kws`.
        """
        if op_name in ['_start_op', '_end_op']:
            return NOPOperator()
        if op_name == '_concat':
            # Tolerate `None` for either argument container.
            return ConcatOperator(*(arg or ()), **(kws or {}))
        # Not an internal operator.
        return None

    def load_operator_from_internal(self, function: str, arg: List[Any], kws: Dict[str, Any], tag: str) -> Operator:  # pylint: disable=unused-argument
        """Load a built-in operator by name; `tag` is ignored.

        Returns:
            (`Operator`)
                The operator instance, or `None` if `function` is not a built-in name.
        """
        return self._load_interal_op(function, arg, kws)

    def load_operator_from_registry(self, function: str, arg: List[Any], kws: Dict[str, Any], tag: str) -> Operator:  # pylint: disable=unused-argument
        """Load an operator resolved through `OperatorRegistry`; `tag` is ignored.

        Returns:
            (`Operator`)
                The operator instance, or `None` if the registry resolves nothing.
        """
        op = OperatorRegistry.resolve(function)
        return self.instance_operator(op, arg, kws) if op is not None else None

    def load_operator_from_packages(self, function: str, arg: List[Any], kws: Dict[str, Any], tag: str) -> Operator:  # pylint: disable=unused-argument
        """Load an operator from an importable `towheeoperator` package; `tag` is ignored.

        `function` is expected in the form `<namespace>/<name>`. The class is looked up
        as `towheeoperator.<namespace>_<name>.<name>.<CamelCaseName>`, with dashes in
        `<name>` replaced by underscores.

        Returns:
            (`Operator`)
                The operator instance, or `None` when anything goes wrong (malformed
                name, import failure, missing class, or a failing constructor), so
                that the next loading strategy can be tried.
        """
        try:
            module, fname = function.split('/')
            fname = fname.replace('-', '_')
            op_cls = self._to_class_name(fname)

            # module = '.'.join([module, fname, fname])
            module = '.'.join(['towheeoperator', '{}_{}'.format(module, fname), fname])
            op = getattr(importlib.import_module(module), op_cls)
            return self.instance_operator(op, arg, kws) if op is not None else None
        except Exception:  # pylint: disable=broad-except
            with param_scope() as hp:
                if hp().towhee.hub.use_pip(False):
                    return None  # TODO: download and install pip package from hub
                else:
                    return None

    def load_operator_from_path(self, path: Union[str, Path], arg: List[Any], kws: Dict[str, Any]) -> Operator:
        """
        Load operator from local path.

        If a `requirements.txt` sits next to the operator file, requirements that are
        not installed yet are installed with pip before the operator is imported.

        Args:
            path (`Union[str, Path]`):
                Path to the operator python file.
            arg (`List[Any]`):
                Positional init args for the operator class.
            kws (`Dict[str, Any]`):
                Keyword init args for the operator class.

        Returns:
            (`Operator`)
                The operator instance.
        """
        path = Path(path)
        fname = path.stem
        modname = 'towhee.operator.' + fname

        self._install_missing_requirements(path.parent / 'requirements.txt')

        try:
            # support for ver1 operator API
            spec = importlib.util.spec_from_file_location(modname, path.resolve())

            # Create the module and then execute the module in its own namespace.
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # By convention, the operator class is simply the CamelCase version of
            # the snake_case operator file name.
            op = getattr(module, self._to_class_name(fname))
        except Exception:  # pylint: disable=broad-except
            # support for ver2 operator API: the operator is exposed under its
            # snake_case name by the package's `__init__.py`.
            init_path = path.parent / '__init__.py'
            spec = importlib.util.spec_from_file_location(modname, init_path.resolve())
            module = importlib.util.module_from_spec(spec)
            # Registered before execution, as in the original ver2 flow.
            sys.modules[modname] = module
            try:
                spec.loader.exec_module(module)
            except Exception:
                # Do not leave a half-initialized module behind in `sys.modules`.
                sys.modules.pop(modname, None)
                raise
            op = getattr(module, fname)

        return self.instance_operator(op, arg, kws) if op is not None else None

    def load_operator_from_cache(self, function: str, arg: List[Any], kws: Dict[str, Any], tag: str) -> Operator:  # pylint: disable=unused-argument
        """Load an operator from the file located by `FileManager`.

        A `function` without a namespace is looked up under `towhee/`.

        Raises:
            ValueError
                `FileManager.get_operator` raised `ValueError` for this operator.
            FileExistsError
                `FileManager.get_operator` returned no path.
        """
        if '/' not in function:
            function = 'towhee/' + function
        try:
            fm = FileManager()
            path = fm.get_operator(operator=function, tag=tag)
        except ValueError as e:
            raise ValueError('operator {} not found!'.format(function)) from e

        if path is None:
            raise FileExistsError('Cannot find operator.')
        return self.load_operator_from_path(path, arg, kws)

    def load_operator(self, function: str, arg: List[Any], kws: Dict[str, Any], tag: str) -> Operator:
        """Attempts to load an operator from cache. If it does not exist, looks up the
        operator in a remote location and downloads it to cache instead. By standard
        convention, the operator must be called `Operator` and all associated data must
        be contained within a single directory.

        The loading strategies are tried in order (internal, registry, packages,
        cache) and the first non-`None` result is returned.

        Args:
            function: (`str`)
                Origin and method/class name of the operator. Used to look up the proper
                operator in cache.
            arg: (`List[Any]`)
                Positional init args for the operator class.
            kws: (`Dict[str, Any]`)
                Keyword init args for the operator class.
            tag: (`str`)
                Operator tag, forwarded to every loading strategy.

        Returns:
            (`Operator`)
                The loaded operator, or `None` if every strategy returned `None`.

        Raises:
            FileExistsError
                Cannot find operator.
        """
        factories = [
            self.load_operator_from_internal,
            self.load_operator_from_registry,
            self.load_operator_from_packages,
            self.load_operator_from_cache,
        ]
        for factory in factories:
            op = factory(function, arg, kws, tag)
            if op is not None:
                return op
        return None

    def instance_operator(self, op, arg: List[Any], kws: Dict[str, Any]) -> Operator:
        """Instantiate `op` with the given init arguments.

        When the `towhee.dry_run` hyper-parameter read through `param_scope` is
        truthy, `op` is not instantiated and a `NOPOperator` is returned instead.

        Args:
            op:
                The operator class (or any callable factory) to instantiate.
            arg (`List[Any]`):
                Positional init args; `None` is treated as no positional args.
            kws (`Dict[str, Any]`):
                Keyword init args; `None` is treated as no keyword args.
        """
        with param_scope() as hp:
            if hp().towhee.dry_run(False):
                return NOPOperator()
            # Normalize both containers so that a `None` in one of them neither
            # crashes the call nor silently discards the other one.
            return op(*(arg or ()), **(kws or {}))