Source code for tools.visualizer

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Union, List, Dict, Any

from towhee.runtime.time_profiler import TimeProfiler
from towhee.tools.profilers import PerformanceProfiler
from towhee.datacollection import DataCollection
from towhee.utils.lazy_import import LazyImport
from towhee.utils.log import engine_log
from towhee.utils.serializer import TritonSerializer, TritonParser


graph_visualizer = LazyImport('graph_visualizer', globals(), 'towhee.tools.graph_visualizer')
data_visualizer = LazyImport('data_visualizer', globals(), 'towhee.tools.data_visualizer')
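# Note: the two visualizer backends above are loaded lazily -- `towhee.tools.graph_visualizer`
# and `towhee.tools.data_visualizer` are only imported on first attribute access
# (e.g. `graph_visualizer.GraphVisualizer`), which keeps this module cheap to import.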


def show_graph(pipeline):
    gv = graph_visualizer.GraphVisualizer(pipeline.dag_repr)
    gv.show()
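
# Usage sketch (illustrative, not part of the original module): `show_graph`
# only needs an object exposing `dag_repr`, such as a built towhee pipeline.
# The variable `pipe` below is a hypothetical, pre-built pipeline.
#
#     from towhee.tools.visualizer import show_graph
#     show_graph(pipe)  # renders the pipeline DAG via GraphVisualizer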


class Visualizer:
    """
    Visualize the debug information.
    """
    def __init__(
        self,
        result: Union['DataQueue', List[Any]] = None,
        time_profiler: List[Any] = None,
        data_queues: List[Dict[str, Any]] = None,
        nodes: Dict[str, Any] = None,
        trace_nodes: List[str] = None,
        node_collection: List[Dict[str, Any]] = None
    ):
        self._result = result
        self._time_profiler = time_profiler
        self._data_queues = data_queues
        self._trace_nodes = trace_nodes
        self._nodes = nodes
        self._node_collection = node_collection if node_collection \
            else [self._get_collection(i) for i in self._get_node_queues()] if self._data_queues \
            else None
        self._profiler = None
        self._tracer = None

    def _get_node_queues(self):
        """
        Get node queue with given graph data queues.
        """
        node_queues = []
        for data_queue in self._data_queues:
            node_queue = {}
            for node in self._nodes.values():
                if node['name'] not in self._trace_nodes:
                    continue
                node_queue[node['name']] = {}
                node_queue[node['name']]['type'] = node['iter_info']['type']
                node_queue[node['name']]['operator'] = node['op_info']['operator']
                node_queue[node['name']]['in'] = [data_queue[edge] for edge in node['inputs']]
                node_queue[node['name']]['out'] = [data_queue[edge] for edge in node['outputs']]
                node_queue[node['name']]['op_input'] = node['op_input']
                node_queue[node['name']]['op_output'] = node['op_output']
                node_queue[node['name']]['next'] = [self._nodes[i]['name'] for i in node['next_nodes']]
            self._set_previous(node_queue)
            node_queues.append(node_queue)

        return node_queues

    def _set_previous(self, node_queue):
        for node in self._nodes.values():
            for i in node['next_nodes']:
                next_node = self._nodes[i]['name']
                if next_node not in self._trace_nodes:
                    continue
                if 'previous' not in node_queue[next_node]:
                    node_queue[next_node]['previous'] = [node['name']]
                else:
                    node_queue[next_node]['previous'].append(node['name'])

    @staticmethod
    def _get_collection(node_info):
        def _to_collection(x):
            for idx, q in enumerate(x):
                if not q.size:
                    q.reset_size()
                tmp = DataCollection(q)
                q.reset_size()
                x[idx] = tmp

        for v in node_info.values():
            _to_collection(v['in'])
            _to_collection(v['out'])

        return node_info

    @property
    def result(self):
        return self._result

    @property
    def profiler(self):
        if not self._time_profiler:
            w_msg = 'Please set `profiler` to `True` when debug, there is nothing to report.'
            engine_log.warning(w_msg)
            return None
        if not self._profiler:
            self._profiler = PerformanceProfiler(self._time_profiler, self._nodes)
        return self._profiler

    @property
    def tracer(self):
        if not self._node_collection:
            w_msg = 'Please set `tracer` to `True` when debug, there is nothing to report.'
            engine_log.warning(w_msg)
            return None
        if not self._tracer:
            self._tracer = data_visualizer.DataVisualizer(self._nodes, self._node_collection)
        return self._tracer

    @property
    def time_profiler(self):
        return self._time_profiler

    @property
    def node_collection(self):
        return self._node_collection

    @property
    def nodes(self):
        return self._nodes

    def _collcetion_to_dict(self):
        #pylint: disable=not-an-iterable
        for info in self._node_collection:
            for node in info.values():
                node['in'] = [i.to_dict() for i in node['in']]
                node['out'] = [i.to_dict() for i in node['out']]

    @staticmethod
    def _dict_to_collcetion(data):
        for info in data['node_collection']:
            for node in info.values():
                node['in'] = [DataCollection.from_dict(i) for i in node['in']]
                node['out'] = [DataCollection.from_dict(i) for i in node['out']]

    def _to_dict(self):
        info = {}
        if self._result:
            info['result'] = self._result
        if self._time_profiler:
            info['time_record'] = [i.time_record for i in self._time_profiler]
        if self._nodes:
            info['nodes'] = self._nodes
        if self._trace_nodes:
            info['trace_nodes'] = self._trace_nodes
        if self._data_queues:
            self._collcetion_to_dict()
            info['node_collection'] = self._node_collection

        return info

    def to_json(self, **kws):
        return json.dumps(self._to_dict(), cls=TritonSerializer, **kws)

    @staticmethod
    def from_json(info):
        info_dict = json.loads(info, cls=TritonParser)
        Visualizer._dict_to_collcetion(info_dict)
        return Visualizer(
            result=info_dict.get('result'),
            time_profiler=[TimeProfiler(enable=True, time_record=i) for i in info_dict.get('time_record')],
            nodes=info_dict.get('nodes'),
            trace_nodes=info_dict.get('trace_nodes'),
            node_collection=info_dict.get('node_collection')
        )
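
# Usage sketch (illustrative, not part of the original module): `vis` is assumed
# to be a Visualizer produced by a debug run with both profiling and tracing data
# collected, so every field survives the JSON round trip below.
#
#     dumped = vis.to_json()                     # serialize the collected debug info
#     restored = Visualizer.from_json(dumped)    # rebuild the Visualizer elsewhere
#     perf = restored.profiler                   # PerformanceProfiler (None without timing data)
#     trace = restored.tracer                    # DataVisualizer (None without traced data)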