# Source code for towhee.functional.mixins.ray

from queue import Queue
import asyncio
import threading
import uuid
from towhee.hub.file_manager import FileManagerConfig

from towhee.utils.log import engine_log
from towhee.functional.option import Option, Empty, _Reason
from towhee.functional.mixins.parallel import EOS


def _map_task_ray(unary_op):  # pragma: no cover
    """Build a per-element callable that applies ``unary_op`` defensively.

    The returned function takes one element. ``Option`` elements are mapped
    through their own ``map`` method; any other element is handed straight to
    ``unary_op``. If applying the op raises, a warning is logged and an
    ``Empty`` wrapping a ``_Reason(element, exception)`` is returned instead of
    propagating the exception.

    Args:
        unary_op: callable applied to each element.

    Returns:
        A one-argument callable implementing the behaviour above.
    """

    def map_wrapper(value):
        try:
            # Option values know how to map themselves; raw values are called directly.
            return value.map(unary_op) if isinstance(value, Option) else unary_op(value)
        except Exception as err:  # pylint: disable=broad-except
            engine_log.warning(f'{err}, please check {value} with op {unary_op}. Continue...')  # pylint: disable=logging-fstring-interpolation
            return Empty(_Reason(value, err))

    return map_wrapper


class RayMixin:  # pragma: no cover
    #pylint: disable=import-outside-toplevel
    """
    Mixin for parallel ray execution.

    The methods below read ``self._num_worker``, ``self._factory`` and
    ``self.map``, iterate over ``self``, and set ``self._backend_started``;
    the host class is expected to provide these (presumably the
    DataCollection family -- verify against the classes that mix this in).
    """

    @staticmethod
    def _pick_num_worker(requested, configured, default=2):
        """Choose a worker count.

        An explicit ``requested`` value wins, then the ``configured`` value,
        then ``default``.

        Args:
            requested (int | None): value passed explicitly by the caller.
            configured (int | None): value configured on the collection.
            default (int): fallback when neither is set.

        Returns:
            int: the worker count to use.
        """
        if requested is not None:
            return requested
        if configured is not None:
            return configured
        return default

    @staticmethod
    def _build_runtime_env(address, local_packages=None, pip_packages=None):
        """Build the ``runtime_env`` dict handed to ``ray.init``.

        The caller's lists are copied, never mutated. When connecting to a
        remote cluster (``address`` is not None) and ``towhee`` is neither a pip
        package nor a local package, ``'towhee'`` is appended to the pip list.

        Args:
            address (str | None): ray address; None means a local ray runtime.
            local_packages (list | None): local modules (module objects or
                paths) to ship as ``py_modules``.
            pip_packages (list[str] | None): pip requirement names.

        Returns:
            dict: ``{'py_modules': [...], 'pip': [...]}``.
        """
        local_packages = [] if local_packages is None else list(local_packages)
        pip_packages = [] if pip_packages is None else list(pip_packages)
        # Module objects expose ``__name__``; plain strings/paths are compared as-is.
        local_names = [str(getattr(pkg, '__name__', pkg)) for pkg in local_packages]
        if address is not None and 'towhee' not in pip_packages and 'towhee' not in local_names:
            pip_packages.append('towhee')
        return {'py_modules': local_packages, 'pip': pip_packages}

    def ray_start(self, address=None, local_packages: list = None, pip_packages: list = None, silence=True):
        """
        Start the ray service.

        When using a remote cluster, all dependencies for custom functions and
        operators defined locally need to be sent to the ray cluster. If using
        ray locally, within the runtime, avoid passing in any arguments.

        Args:
            address (str):
                The address of the ray service to connect to. If using a ray
                cluster remotely with a kubectl-forwarded port, the most likely
                address is "ray://localhost:10001".
            local_packages (list):
                Locally defined modules (module objects, or paths accepted by
                ray's ``py_modules``) used within custom functions supplied to
                the pipeline, whether in lambda functions, locally registered
                operators, or plain functions. The list is not modified.
            pip_packages (list[str]):
                Pip-installed modules used within custom functions supplied to
                the pipeline. The list is not modified; ``'towhee'`` is added
                to the runtime environment automatically for remote clusters.
            silence (bool):
                When True (the default), worker logs are not forwarded to the
                driver.

        Returns:
            self, to allow chaining.
        """
        import ray
        runtime_env = self._build_runtime_env(address, local_packages, pip_packages)
        # ``log_to_driver=True`` forwards worker output to the driver, so
        # silencing means *not* forwarding it.
        ray.init(address=address, runtime_env=runtime_env, ignore_reinit_error=True, log_to_driver=not silence)
        self._backend_started = True
        return self

    def ray_resolve(self, call_mapping, path, index, *arg, **kws):
        """Run the operator at ``path`` over this collection using a pool of ray actors.

        If ``path`` is in ``call_mapping``, the mapped factory is called with
        ``*arg, **kws`` and the result is applied locally through ``self.map``.
        Otherwise, one ``OperatorActor`` per worker loads the hub operator and
        elements are dispatched through a ``ray.util.ActorPool``. A background
        thread feeds results into a bounded queue consumed by the returned
        collection. Errors raised in that thread are re-raised to the consumer
        instead of leaving it blocked.

        Args:
            call_mapping (dict): paths that are resolved locally instead of on ray.
            path: operator path passed to ``_OperatorLazyWrapper.callback``.
            index: operator index passed to ``_OperatorLazyWrapper.callback``.
            *arg, **kws: operator construction arguments.

        Returns:
            A new collection built by ``self._factory`` over the results.
        """
        import ray

        #TODO: Make local functions work with ray
        if path in call_mapping:
            return self.map(call_mapping[path](*arg, **kws))

        @ray.remote
        class OperatorActor:
            """Ray actor that runs hub operators."""

            def __init__(self, path1, index1, uid, *arg1, **kws1):
                from towhee import engine
                from towhee.engine.factory import _OperatorLazyWrapper
                from pathlib import Path

                # Each actor gets its own cache root keyed by ``uid``
                # (presumably so concurrent actors do not share cache dirs).
                engine.DEFAULT_LOCAL_CACHE_ROOT = Path.home() / ('.towhee/ray_actor_cache_' + uid)
                engine.LOCAL_PIPELINE_CACHE = engine.DEFAULT_LOCAL_CACHE_ROOT / 'pipelines'
                engine.LOCAL_OPERATOR_CACHE = engine.DEFAULT_LOCAL_CACHE_ROOT / 'operators'
                x = FileManagerConfig()
                x.update_default_cache(engine.DEFAULT_LOCAL_CACHE_ROOT)
                self.op = _OperatorLazyWrapper.callback(path1, index1, *arg1, **kws1)

            def __call__(self, *arg1, **kwargs1):
                return self.op(*arg1, **kwargs1)

            def cleanup(self):
                from shutil import rmtree
                from towhee import engine
                try:
                    rmtree(engine.DEFAULT_LOCAL_CACHE_ROOT)
                except FileNotFoundError:
                    pass

        num_worker = self._pick_num_worker(None, self._num_worker)
        actors = [
            OperatorActor.remote(path, index, uuid.uuid4().hex[:12].upper(), *arg, **kws)
            for _ in range(num_worker)
        ]
        pool = ray.util.ActorPool(actors)
        queue = Queue(num_worker)
        errors = []  # exceptions captured in the worker thread

        def inner():
            while True:
                x = queue.get()
                if isinstance(x, EOS):
                    break
                yield x
            for actor in actors:
                actor.cleanup.remote()
            if errors:
                raise errors[0]

        def worker():
            try:
                for x in self:
                    # Drain finished results until an actor is free to take new work.
                    while not pool.has_free():
                        if pool.has_next():
                            queue.put(pool.get_next())
                    pool.submit(lambda a, v: a.__call__.remote(v), x)
                while pool.has_next():
                    queue.put(pool.get_next())
            except Exception as e:  # pylint: disable=broad-except
                errors.append(e)
            finally:
                # Always terminate the consumer, even on failure.
                queue.put(EOS())

        t = threading.Thread(target=worker)
        t.start()
        child = self._factory(inner())
        return child

    def _ray_pmap(self, unary_op, num_worker=None):
        """Map ``unary_op`` over this collection in parallel with ray remote tasks.

        At most ``num_worker`` tasks are in flight at once, and results are
        emitted in input order. Each element goes through ``_map_task_ray``, so
        per-element failures become ``Empty`` values. Errors outside the op
        (e.g. while iterating the source) are re-raised to the consumer.

        Args:
            unary_op: callable applied to each element.
            num_worker (int | None): number of concurrent tasks. Defaults to
                ``self._num_worker``, or 2 when that is unset.

        Returns:
            A new collection built by ``self._factory`` over the results.
        """
        import ray

        num_worker = self._pick_num_worker(num_worker, self._num_worker)
        queue = Queue(num_worker)
        errors = []  # exceptions captured in the worker thread

        def inner():
            while True:
                x = queue.get()
                if isinstance(x, EOS):
                    break
                yield x
            if errors:
                raise errors[0]

        @ray.remote
        def remote_runner(val):
            return _map_task_ray(unary_op)(val)

        async def worker():
            buff = []
            for x in self:
                # Keep at most ``num_worker`` tasks in flight; emit the oldest first.
                if len(buff) == num_worker:
                    queue.put(await buff.pop(0))
                buff.append(asyncio.wrap_future(remote_runner.remote(x).future()))
            while buff:
                queue.put(await buff.pop(0))

        def worker_wrapper():
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(worker())
            except Exception as e:  # pylint: disable=broad-except
                errors.append(e)
            finally:
                loop.close()
                # Always terminate the consumer, even on failure.
                queue.put(EOS())

        t = threading.Thread(target=worker_wrapper)
        t.start()
        return self._factory(inner())