Source code for towhee.functional.mixins.parallel

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import asyncio
import threading
import time
from queue import Queue

try:
    import torch
except: # pylint: disable=bare-except
    pass

from towhee.utils.log import engine_log
from towhee.functional.option import Option, Empty, _Reason
from towhee.hparam.hyperparameter import param_scope
from towhee.functional.storages import WritableTable, ChunkedTable

stream = threading.local()
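# Thread-local storage for per-thread CUDA streams: each worker thread
# spawned by the thread pool keeps its own `torch.cuda.Stream` here
# (populated by `initializer` below), so GPU work issued from different
# worker threads does not serialize on the default stream.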


def initializer():
    # pylint: disable=bare-except
    try:
        if torch.cuda.is_available():
            stream.stream = torch.cuda.Stream()
    except:
        pass

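# A minimal sketch of how `initializer` gets wired in (it mirrors
# `set_parallel` below, assuming a CUDA-capable torch install): each pool
# thread runs it once at startup, so `stream.stream` exists before any task
# is scheduled on that thread.
#
#     executor = concurrent.futures.ThreadPoolExecutor(
#         max_workers=4, initializer=initializer)
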
class ParallelMixin:
    """
    Mixin for parallel execution.

    Examples:

    >>> from towhee import DataCollection
    >>> def add_1(x):
    ...     return x+1
    >>> result = DataCollection.range(1000).set_parallel(2).map(add_1).to_list()
    >>> len(result)
    1000

    >>> from towhee import dc
    >>> dc = dc['a'](range(1000)).set_parallel(5)
    >>> dc = dc.runas_op['a', 'b'](lambda x: x+1).to_list()
    >>> len(dc)
    1000

    >>> from towhee import dc
    >>> dc = dc['a'](range(1000)).set_parallel(5).set_chunksize(2)
    >>> dc = dc.runas_op['a', 'b'](lambda x: x+1)
    >>> dc._iterable.chunks()[:2]
    [pyarrow.Table
    a: int64
    b: int64
    ----
    a: [[0,1]]
    b: [[1,2]], pyarrow.Table
    a: int64
    b: int64
    ----
    a: [[2,3]]
    b: [[3,4]]]

    >>> result = DataCollection.range(1000).pmap(add_1, 10).pmap(add_1, 10).to_list()
    >>> result[990:]
    [992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001]
    """

    def __init__(self) -> None:
        super().__init__()
        with param_scope() as hp:
            parent = hp().data_collection.parent(None)
            if parent is not None and hasattr(parent, '_executor'):
                self._backend = parent._backend
                self._executor = parent._executor
                self._num_worker = parent._num_worker

    def get_executor(self):
        if hasattr(self, '_executor'):
            return self._executor
        return None

    def get_backend(self):
        if hasattr(self, '_backend') and isinstance(self._backend, str):
            return self._backend
        return None

    def get_num_worker(self):
        if hasattr(self, '_num_worker'):
            return self._num_worker
        return None

    def set_parallel(self, num_worker=2, backend='thread'):
        """
        Set parallel execution for following calls.

        Examples:

        >>> from towhee import DataCollection
        >>> import threading
        >>> stage_1_thread_set = set()
        >>> stage_2_thread_set = set()
        >>> result = (
        ...     DataCollection.range(1000).stream().set_parallel(4)
        ...     .map(lambda x: stage_1_thread_set.add(threading.current_thread().ident))
        ...     .map(lambda x: stage_2_thread_set.add(threading.current_thread().ident)).to_list()
        ... )
        >>> len(stage_2_thread_set)>1
        True
        """
        self._backend = backend
        self._num_worker = num_worker
        if self._backend == 'thread' and self._num_worker is not None:
            self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_worker,
                                                                   initializer=initializer)
        else:
            # clear executor
            self._executor = None
        return self

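    # Note: only the `thread` backend owns a pool here; any other backend
    # (e.g. `ray`) clears `_executor`, and the actual dispatch happens per
    # call in `pmap` below.
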
    def split(self, count):
        """
        Split a data collection into multiple data collections.

        Args:
            count (int): how many resulting DCs;

        Returns:
            [DataCollection, ...]: copies of DC;

        Examples:

        1. Split:

        >>> from towhee import DataCollection
        >>> dc = DataCollection([0, 1, 2, 3, 4]).stream()
        >>> a, b, c = dc.split(3)
        >>> a.zip(b, c).to_list()
        [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
        """
        # Figure out better optimization
        if self.is_stream:
            queues = [Queue(count) for _ in range(count)]
        else:
            queues = [Queue() for _ in range(count)]
        loop = asyncio.new_event_loop()

        def inner(queue):
            while True:
                x = queue.get()
                if isinstance(x, EOS):
                    break
                else:
                    yield x

        async def worker():
            cached_values = {x: [] for x in range(count)}
            for x in self:
                # TODO: Use some kind of event instead of wait
                sleepy = .01
                while all(y.full() for y in queues):
                    time.sleep(sleepy)
                for i, queue in enumerate(queues):
                    if len(cached_values[i]) > 0:
                        while not queue.full() and len(cached_values[i]) > 0:
                            queue.put(cached_values[i].pop(0))
                    if len(cached_values[i]) == 0 and not queue.full():
                        queue.put(x)
                    else:
                        cached_values[i].append(x)

            for i, queue in enumerate(queues):
                poison = EOS()
                cached_values[i].append(poison)

            while len(cached_values) > 0:
                for x in list(cached_values.keys()):
                    if len(cached_values[x]) == 0:
                        del cached_values[x]
                    else:
                        while not queues[x].full() and len(cached_values[x]) > 0:
                            queues[x].put(cached_values[x].pop(0))

        def worker_wrapper():
            loop.run_until_complete(worker())
            loop.close()

        t = threading.Thread(target=worker_wrapper, daemon=True)
        t.start()

        retval = [inner(queue) for queue in queues]
        return [self._factory(x) for x in retval]

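    # Design note for `split`: a single producer thread fans every element
    # out to all consumer queues. Bounded queues (stream mode) provide
    # backpressure, `cached_values` buffers elements for queues that are
    # momentarily full, and one EOS poison pill per queue lets each `inner`
    # generator terminate cleanly.
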
    def _map_task(self, x, unary_op):
        def inner():
            try:
                if isinstance(x, Option):
                    res = x.map(unary_op)
                elif isinstance(x, WritableTable):
                    res = WritableTable(self.__table_apply__(x, unary_op))
                else:
                    res = unary_op(x)
                return res
            except Exception as e:  # pylint: disable=broad-except
                engine_log.warning(f'{e}, please check {x} with op {unary_op}. Continue...')  # pylint: disable=logging-fstring-interpolation
                return Empty(_Reason(x, e))

        def map_wrapper():
            if hasattr(stream, 'stream'):
                torch.cuda.synchronize()
                with torch.cuda.stream(stream.stream):
                    res = inner()
                torch.cuda.synchronize()
                return res
            else:
                return inner()

        return map_wrapper

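    # Design note for `_map_task`: failures are routed, not raised. Any
    # exception from `unary_op` is logged and wrapped in `Empty(_Reason(x, e))`,
    # so a parallel pipeline keeps flowing and downstream stages can filter
    # out failed elements. When the worker thread owns a CUDA stream, the
    # task runs inside that stream.
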
    def pmap(self, unary_op, num_worker=None, backend=None):
        """
        Apply `unary_op` with parallel execution.
        Currently supports two backends, `ray` and `thread`.

        Args:
            unary_op (func): the op to be mapped;
            num_worker (int): how many threads to reserve for this op;
            backend (str): whether to use `ray` or `thread`;

        Examples:

        >>> from towhee import DataCollection
        >>> import threading
        >>> stage_1_thread_set = {threading.current_thread().ident}
        >>> stage_2_thread_set = {threading.current_thread().ident}
        >>> result = (
        ...     DataCollection.range(1000).stream()
        ...     .pmap(lambda x: stage_1_thread_set.add(threading.current_thread().ident), 5)
        ...     .pmap(lambda x: stage_2_thread_set.add(threading.current_thread().ident), 4).to_list()
        ... )
        >>> len(stage_1_thread_set) > 1
        True
        >>> len(stage_2_thread_set) > 1
        True
        """
        if backend is None:
            backend = self.get_backend()
        if backend == 'ray':
            return self._ray_pmap(unary_op, num_worker)
        return self._thread_pmap(unary_op, num_worker)

    def _thread_pmap(self, unary_op, num_worker=None):
        if num_worker is None and self.get_num_worker() is None:
            num_worker = 2
        if num_worker is not None:
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_worker,
                                                             initializer=initializer)
        elif self.get_executor() is not None:
            executor = self._executor
            num_worker = self._num_worker

        queue = Queue(num_worker)
        loop = asyncio.new_event_loop()

        def inner():
            while True:
                x = queue.get()
                queue.task_done()
                if isinstance(x, EOS):
                    break
                else:
                    yield x

        async def worker():
            buff = []
            iterable = self._iterable.chunks() if isinstance(self._iterable, ChunkedTable) else self
            for x in iterable:
                if len(buff) == num_worker:
                    queue.put(await buff.pop(0))
                buff.append(loop.run_in_executor(executor, self._map_task(x, unary_op)))
            while len(buff) > 0:
                queue.put(await buff.pop(0))
            queue.put(EOS())

        def worker_wrapper():
            loop.run_until_complete(worker())
            loop.close()

        t = threading.Thread(target=worker_wrapper, daemon=True)
        t.start()

        res = inner()
        if isinstance(self._iterable, ChunkedTable):
            if not self.is_stream:
                res = list(res)
            res = ChunkedTable(chunks=res)
        return self._factory(res)

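    # Design note for `_thread_pmap`: `buff` holds at most `num_worker`
    # in-flight futures and is awaited FIFO, so results enter `queue` in
    # submission order and memory stays bounded. Output order therefore
    # matches input order even though the ops run concurrently.
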
    def mmap(self, ops: list, num_worker=None, backend=None):
        """
        Apply multiple unary ops to the data collection.
        Currently supports two backends, `ray` and `thread`.

        Args:
            ops (list): the ops to be mapped;
            num_worker (int): how many threads to reserve for this op;
            backend (str): whether to use `ray` or `thread`;

        # TODO: the test is broken with pytest
        # Examples:

        # 1. Using mmap:

        # >>> from towhee import DataCollection
        # >>> dc1 = DataCollection([0,1,2,'3',4]).stream()
        # >>> a1, b1 = dc1.mmap([lambda x: x+1, lambda x: x*2])
        # >>> c1 = a1.map(lambda x: x+1)
        # >>> c1.zip(b1).to_list()
        # [(2, 0), (3, 2), (4, 4), (Empty(), '33'), (6, 8)]

        # 2. Using map instead of mmap:

        # >>> from towhee import DataCollection
        # >>> dc2 = DataCollection.range(5).stream()
        # >>> a2, b2, c2 = dc2.map(lambda x: x+1, lambda x: x*2, lambda x: int(x/2))
        # >>> d2 = a2.map(lambda x: x+1)
        # >>> d2.zip(b2, c2).to_list()
        # [(2, 0, 0), (3, 2, 0), (4, 4, 1), (5, 6, 1), (6, 8, 2)]

        # 3. DAG execution:

        # >>> dc3 = DataCollection.range(5).stream()
        # >>> a3, b3, c3 = dc3.map(lambda x: x+1, lambda x: x*2, lambda x: int(x/2))
        # >>> d3 = a3.map(lambda x: x+1)
        # >>> d3.zip(b3, c3).map(lambda x: x[0]+x[1]+x[2]).to_list()
        # [2, 5, 9, 12, 16]
        """
        if len(ops) == 1:
            return self.pmap(unary_op=ops[0], num_worker=num_worker, backend=backend)

        next_vals = self.split(len(ops))
        ret = []
        for i, x in enumerate(ops):
            ret.append(next_vals[i].pmap(x, num_worker=num_worker, backend=backend))
        return ret

class EOS:
    '''
    Internal sentinel object used to signal the end of a processing queue.
    '''
    pass
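
# A minimal end-to-end usage sketch (assuming the towhee package is
# installed); it mirrors the doctests above rather than adding new API:
#
#     from towhee import DataCollection
#
#     dc = DataCollection.range(100).stream().set_parallel(num_worker=4)
#     result = dc.pmap(lambda x: x * x).to_list()
#     assert len(result) == 100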