# Source code for towhee.runtime.factory

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=unused-import
# pylint: disable=dangerous-default-value

from typing import Any, Dict, Tuple, Callable

from towhee.runtime.operator_manager import OperatorLoader, OperatorRegistry
from towhee.utils.log import engine_log


class OpsParser:
    """
    Callable proxy that turns chained attribute access into a dotted name.

    Every attribute lookup that normal resolution cannot satisfy produces a
    new parser (built through ``ops_parse``) whose name is the current name
    extended with the requested attribute, joined by ``.``. Calling a parser
    invokes the wrapped function with the accumulated name as its first
    positional argument, followed by the call's own arguments.

    Example:
    >>> from towhee.runtime.factory import ops_parse
    >>> @ops_parse
    ... def foo(name, *args, **kwargs):
    ...     return str((name, args, kwargs))
    >>> print(foo.bar.zed(1, 2, 3))
    ('bar.zed', (1, 2, 3), {})
    """

    def __init__(self, func: Callable, name=None):
        # ``name`` stays ``None`` until the first attribute access.
        self._name = name
        self._func = func

    def __call__(self, *args, **kws) -> Any:
        # Forward the call, prepending the dotted name collected so far.
        target = self._func
        return target(self._name, *args, **kws)

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes that are not found the regular way.
        prefix = self._name
        dotted = name if prefix is None else f'{prefix}.{name}'
        return ops_parse(self._func, dotted)


def ops_parse(func, name=None):
    """
    Wrap ``func`` in a freshly created ``OpsParser`` subclass instance.

    A plain function cannot intercept attribute access, so a new class is
    generated that is named after ``func``, carries its docstring and
    inherits the ``__getattr__`` / ``__call__`` behaviour of ``OpsParser``.

    Args:
        func: The callable to wrap; it is later called with the dotted name
            as first positional argument.
        name: Initial dotted name held by the parser (``None`` for the root).

    Returns:
        An instance of the generated class wrapping ``func``.
    """
    bases = (OpsParser, object)
    namespace = {'__doc__': func.__doc__}
    parser_cls = type(func.__name__, bases, namespace)
    return parser_cls(func, name)


class _OperatorWrapper:
    """
    Lazy handle on an operator.

    The wrapper records the operator name, its init arguments, a revision
    tag and a ``latest`` flag. The operator itself is only loaded (through
    ``OperatorLoader``) the first time it is requested or called.
    """

    def __init__(self,
                 name: str,
                 init_args: Tuple = None,
                 init_kws: Dict[str, Any] = None,
                 tag: str = 'main',
                 latest: bool = False):
        # Normalise the dotted attribute path: '.' becomes '/' and '_' becomes '-'.
        self._name = name.translate(str.maketrans({'.': '/', '_': '-'}))
        self._init_args = init_args
        self._init_kws = init_kws
        self._tag = tag
        self._latest = latest
        # Filled in by ``preload_op`` on first use.
        self._op = None

    @property
    def name(self):
        """Normalised operator name."""
        return self._name

    @property
    def tag(self):
        """Revision tag passed to the loader."""
        return self._tag

    @property
    def init_args(self):
        """Positional arguments forwarded to the loader."""
        return self._init_args

    @property
    def init_kws(self):
        """Keyword arguments forwarded to the loader."""
        return self._init_kws

    @property
    def is_latest(self):
        """Value of the ``latest`` flag passed to the loader."""
        return self._latest

    def revision(self, tag: str = 'main'):
        """Set the revision tag and return ``self`` so calls can be chained."""
        self._tag = tag
        return self

    def latest(self):
        """Turn on the ``latest`` flag and return ``self`` so calls can be chained."""
        self._latest = True
        return self

    def get_op(self):
        """Return the loaded operator, loading it first when necessary."""
        if self._op is not None:
            return self._op
        self.preload_op()
        return self._op

    def preload_op(self):
        """
        Load the operator and cache it on the wrapper.

        Raises:
            RuntimeError: If creating the loader or loading the operator
                raises; the message is logged and the original exception
                is chained as the cause.
        """
        try:
            self._op = OperatorLoader().load_operator(
                self._name,
                self._init_args,
                self._init_kws,
                tag=self._tag,
                latest=self._latest,
            )
        except Exception as e:
            err = f'Loading operator with error:{e}'
            engine_log.error(err)
            raise RuntimeError(err) from e

    def __call__(self, *args, **kws):
        """Invoke the operator with the given arguments, loading it on first call."""
        if self._op is None:
            self.preload_op()
        return self._op(*args, **kws)

    @staticmethod
    def callback(name: str, *args, **kws):
        """Build a wrapper from a dotted name plus the operator's init arguments."""
        return _OperatorWrapper(name, args, kws)


class _OperatorParser:
    """
    Entry point for building operators through ``_OperatorWrapper``.

    An operator is usually referred to with its full name: namespace/name.
    Attribute access on an instance collects the dotted path (for example
    ``towhee.image_decode``), and calling the result hands that path plus the
    call arguments to ``_OperatorWrapper.callback``.

    Examples:
        >>> from towhee import ops
        >>> op = ops.towhee.image_decode()
        >>> img = op('./towhee_logo.png')

        We can also specify the version of the operator on the hub via the
        `revision` method:

        >>> op = ops.towhee.image_decode().revision('main')

        And the `latest` method is used to update the current version of the
        operator to the latest:

        >>> op = ops.towhee.image_decode().latest()
    """

    @classmethod
    def __getattr__(cls, name):
        # The function must stay named ``wrapper``: ``ops_parse`` uses
        # ``func.__name__`` as the name of the class it generates.
        @ops_parse
        def wrapper(op_name, *args, **kws):
            # ``op_name`` is the dotted path accumulated by the parser.
            return _OperatorWrapper.callback(op_name, *args, **kws)

        # Start the attribute chain with the first requested segment.
        return getattr(wrapper, name)
# Module-level parser instance: ``ops.<namespace>.<operator>(...)`` builds an
# ``_OperatorWrapper`` for the named operator.
ops = _OperatorParser()
# Module-level alias for ``OperatorRegistry.register``.
register = OperatorRegistry.register