Source code for towhee.functional.mixins.dag

from functools import wraps
from uuid import uuid4
from abc import ABCMeta

from towhee.hparam import param_scope

def register_dag(f):
    """Decorator that records a call to ``f`` as a node of the DAG.

    The wrapped callable is executed first; its result (a single
    DataCollection or a list of them) is returned unchanged.  Afterwards a
    node description is stored in the control plane:

    - when ``self`` is an instance of the result's DataCollection type, the
      node is stored under ``self.id``;
    - otherwise (static or class method style call) the node is stored under
      the key ``'start'`` in the control plane of every returned child.

    Args:
        f: The method to wrap.

    Returns:
        The wrapping function, carrying ``f``'s metadata via ``functools.wraps``.
    """

    @wraps(f)
    def wrapper(self, *arg, **kws):
        # Run the wrapped method; the resulting DataCollection(s) become the children.
        children = f(self, *arg, **kws)

        # The dc type is read off the result to avoid a circular import.
        dc_type = type(children[0]) if isinstance(children, list) else type(children)

        # Index information left in the param scope by the calling dc, if any.
        index_info = param_scope()._index  # pylint: disable=protected-access
        input_info = None
        output_info = None
        if isinstance(index_info, tuple):
            # (inputs, outputs) pair; each side is a single name or a tuple of names.
            source, target = index_info[0], index_info[1]
            input_info = list(source) if isinstance(source, tuple) else [source]
            output_info = list(target) if isinstance(target, tuple) else [target]
        elif index_info is not None:
            # A lone, non-tuple index only describes the output.
            output_info = [index_info]

        def collect_child_ids():
            # Either a single dc or an iterable of dcs.
            if isinstance(children, dc_type):
                return [children.id]
            return [child.id for child in children]

        if not isinstance(self, dc_type):
            # Not called from a dc: treat as a static or class method.
            # When called from a classmethod, avoid passing in ``self``.
            if isinstance(self, (dc_type, ABCMeta)):
                pass_args = arg
            else:
                pass_args = (self,) + arg
            start_info = {
                'op': f.__name__,
                'op_name': None,
                'is_stream': None,
                'init_args': None,
                'call_args': {'*arg': pass_args, '*kws': kws},
                'op_config': None,
                'input_info': input_info,
                'output_info': output_info,
                'parent_ids': [],
                'child_ids': collect_child_ids(),
            }
            # A call that does not originate from a dc is a start method,
            # so it is added to the DAG of every child.
            for child in children if isinstance(children, list) else [children]:
                child.get_control_plane().dag['start'] = start_info
            return children

        # Called from an existing dc: describe the operation on ``self``.
        self.op = f.__name__
        self.call_args = {'*arg': arg, '*kws': kws}
        if arg and hasattr(arg[0], '_op_config'):
            self.op_config = arg[0].op_config
        self.child_ids = collect_child_ids()

        info = {
            'op': self.op,
            'op_name': self.op_name,
            'is_stream': self.is_stream,
            'init_args': self.init_args,
            'call_args': self.call_args,
            'op_config': self.get_pipeline_config(),
            'input_info': input_info,
            'output_info': output_info,
            'parent_ids': self.parent_ids,
            'child_ids': self.child_ids,
        }
        # An op-level config, when present, is merged into the pipeline config.
        if self.op_config is not None:
            info['op_config'].update(self.op_config)
        self.get_control_plane().dag[self.id] = info
        return children

    return wrapper
class DagMixin:  # pylint: disable=import-outside-toplevel
    """
    Mixin for creating DAGs and their corresponding yamls from a DC.

    Every instance gets a short unique ``id`` and a control plane whose
    ``dag`` dictionary maps node ids to node descriptions.  A node description
    is a dict with the keys ``op``, ``op_name``, ``init_args``, ``call_args``,
    ``op_config``, ``input_info``, ``output_info``, ``parent_ids`` and
    ``child_ids`` (nodes written by ``register_dag`` also carry ``is_stream``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Unique id for current dc
        self.id = str(uuid4().hex[:8])

        with param_scope() as hp:
            parent = hp().data_collection.parent(None)
        if parent is None:
            # No parent in scope: this dc hangs off 'start' with its own control plane.
            self.parent_ids = ['start']
            self._control_plane = ControlPlane()
        else:
            # Derived dc: link to the parent and reuse its control plane.
            self.parent_ids = [parent.id]
            self._control_plane = parent._control_plane

        self.op = None
        self.op_name = None
        self.init_args = None
        self.call_args = None
        self.op_config = None
        self.input_info = None
        self.output_info = None
        self.child_ids = []

    @staticmethod
    def _bare_node(op, parent_ids, child_ids):
        """Build a node description that carries only ``op`` and its links."""
        return {
            'op': op,
            'op_name': None,
            'init_args': None,
            'call_args': None,
            'op_config': None,
            'input_info': None,
            'output_info': None,
            'parent_ids': parent_ids,
            'child_ids': child_ids,
        }

    def register_dag(self, children):
        """Store this dc's node in the control plane and return ``children``.

        Args:
            children: A single dc of the same type as ``self`` or an iterable
                of objects exposing ``id``.

        Returns:
            ``children``, unchanged.
        """
        # check if list of dc or just dc
        if isinstance(children, type(self)):
            self.child_ids = [children.id]
        else:
            self.child_ids = [x.id for x in children]
        info = {
            'op': self.op,
            'op_name': self.op_name,
            'is_stream': self.is_stream,
            'init_args': self.init_args,
            'call_args': self.call_args,
            'op_config': self.op_config,
            'input_info': self.input_info,
            'output_info': self.output_info,
            'parent_ids': self.parent_ids,
            'child_ids': self.child_ids,
        }
        self._control_plane.dag[self.id] = info
        return children

    def notify_consumed(self, new_id):
        """Register this dc as a ``'nop'`` node whose only child is ``new_id``."""
        self._control_plane.dag[self.id] = self._bare_node('nop', self.parent_ids, [new_id])

    def compile_dag(self):
        """Terminate the DAG at this dc and return a cleaned-up copy of it.

        This dc is registered as a ``'nop'`` node pointing at a new ``'end'``
        node.  The returned dictionary is produced by the clean-up passes
        (nop removal, op naming, start re-keying, input resolution), which run
        on a private copy so the nodes stored in the control plane are not
        rewritten and the method can be called repeatedly.

        Returns:
            dict: node id -> node description for the compiled DAG.
        """
        self._control_plane.dag[self.id] = self._bare_node('nop', self.parent_ids, ['end'])
        self._control_plane.dag['end'] = self._bare_node('end', [self.id], [])
        return self._clean_nops(self._control_plane.dag)

    def netx(self):
        """Compile the DAG and draw it with networkx and matplotlib."""
        import networkx as nx
        import matplotlib.pyplot as plt

        compiled_dag = self.compile_dag()
        # Adjacency (node -> children) and the label shown for each node.
        new_dict = {key: value['child_ids'] for key, value in compiled_dag.items()}
        label_dict = {key: value['op'] for key, value in compiled_dag.items()}
        label_dict['end'] = 'end'

        g = nx.DiGraph(new_dict)
        pos = nx.nx_pydot.graphviz_layout(g, 'dot')
        nx.draw_networkx(g, pos, labels=label_dict, with_labels=True)
        plt.show()

    def _clean_nops(self, dag):
        """Return the compiled DAG with every ``'nop'`` node spliced out.

        ``dag`` itself is left untouched: the work happens on a copy with
        fresh node dicts and fresh ``parent_ids`` / ``child_ids`` lists.  A
        plain ``dict.copy()`` would share the node dicts, so the rewiring
        below (and the later passes) would corrupt the control plane and make
        a second compilation fail.
        """
        dag_copy = {
            key: {**node, 'parent_ids': list(node['parent_ids']), 'child_ids': list(node['child_ids'])}
            for key, node in dag.items()
        }
        removals = []
        # Read nodes from the working copy so that chained nops see the links
        # already rewired by an earlier nop.
        for key, val in list(dag_copy.items()):
            if val['op'] != 'nop':
                continue
            removals.append(key)
            # Each parent adopts the nop's children ...
            for parent in val['parent_ids']:
                siblings = dag_copy[parent]['child_ids']
                siblings.remove(key)
                # dict.fromkeys de-duplicates while keeping a stable order.
                dag_copy[parent]['child_ids'] = list(dict.fromkeys(siblings + val['child_ids']))
            # ... and each child adopts the nop's parents.
            for child in val['child_ids']:
                ancestors = dag_copy[child]['parent_ids']
                ancestors.remove(key)
                dag_copy[child]['parent_ids'] = list(dict.fromkeys(ancestors + val['parent_ids']))
        for key in removals:
            del dag_copy[key]
        return self._add_op_name_and_init_args(dag_copy)

    def _add_op_name_and_init_args(self, dag):
        """Fill in ``op_name`` / ``init_args`` and re-key the input node as ``'start'``.

        Works in place on ``dag`` (the private copy made by ``_clean_nops``).

        Raises:
            ValueError: If no node qualifies as the input node.
        """
        start = None
        for key, val in dag.items():
            if (val['op'] == 'map' or val['op'] == 'filter') or val['call_args'] is not None:
                call_arg = val['call_args']['*arg']
                if call_arg != () and hasattr(call_arg[0], '_kws'):
                    # First positional argument looks like an operator: take
                    # its init args and qualify its name with 'towhee/' if it
                    # has no namespace of its own.
                    operator = call_arg[0]
                    dag[key]['init_args'] = operator.init_args
                    if len(operator.function.split('/')) > 1:
                        dag[key]['op_name'] = operator.function
                    else:
                        dag[key]['op_name'] = 'towhee/' + operator.function
                else:
                    # No operator argument: treated as the input node.  If
                    # several nodes match, the last one wins.
                    dag[key]['op_name'] = 'dummy_input'
                    dag[key]['parent_ids'] = []
                    start = key
            else:
                dag[key]['op_name'] = 'end'

        if start is None:
            raise ValueError('No input node found in DAG, cannot determine the start node.')

        # NOTE(review): when the input node is already keyed 'start' this pair
        # of statements removes it from the result -- confirm that is intended.
        dag['start'] = dag[start]
        del dag[start]
        return self._solve_child_ids(dag, start)

    def _solve_child_ids(self, dag, start):
        """Point every node that had ``start`` as a parent at ``'start'`` instead.

        Works in place on ``dag``.
        """
        for val in dag.values():
            if start in val['parent_ids']:
                val['parent_ids'].append('start')
                val['parent_ids'].remove(start)
        return self._solve_input_info(dag)

    def _solve_input_info(self, dag):
        """Turn each ``input_info`` list into ``(parent_id, name)`` tuples.

        A name is attributed to ``'start'`` whenever ``'start'`` is a parent,
        and to any other parent whose ``output_info`` contains that name.
        Works in place on ``dag`` and returns it.
        """
        for value in dag.values():
            if not isinstance(value['input_info'], list):
                continue
            input_info = []
            for name in value['input_info']:
                for parent in value['parent_ids']:
                    if parent == 'start':
                        input_info.append(('start', name))
                    elif dag[parent]['output_info'] is not None and name in dag[parent]['output_info']:
                        input_info.append((parent, name))
            value['input_info'] = input_info
        return dag

    def _clean_streams(self, dag):
        """Not implemented."""
        raise NotImplementedError

    def get_control_plane(self):
        """Return the control plane object holding this dc's DAG."""
        return self._control_plane
class ControlPlane:
    """Container for the dictionary in which DAG nodes are recorded."""

    def __init__(self) -> None:
        # Starts out empty; entries are added through the ``dag`` property.
        self._dag = dict()

    @property
    def dag(self):
        """The underlying (mutable) DAG dictionary."""
        return self._dag