Source code for towhee.engine.graph_context

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Tuple

from towhee.dataframe import DataFrame
from towhee.dag import GraphRepr
from towhee.dataframe.iterators import MapIterator

from towhee.engine.operator_context import OperatorContext, OpStatus
from towhee.utils.log import engine_log
from towhee.errors import OpFailedError


[docs]class GraphContext: """ `GraphContext` is a data processing network with one or multiple `Operator`. Each row of a `Pipeline`'s inputs will be processed individually by a `GraphContext`. Args: ctx_idx: (`int`) The index of this `GraphContext`. graph_repr: (`towhee.dag.GraphRepr`) The DAG representation this `GraphContext` will implement. """
[docs] def __init__(self, ctx_idx: int, graph_repr: GraphRepr): self._idx = ctx_idx self._repr = graph_repr self._lock = threading.Lock() self._build_components()
[docs] def __call__(self, inputs: Tuple): self.inputs.put(inputs) self.inputs.seal()
@property def inputs(self) -> DataFrame: """Returns the graph's input `DataFrame`. Returns: (`towhee.dataframe.DataFrame`) The input `DataFrame`. """ return self.dataframes['_start_df'] @property def outputs(self) -> DataFrame: """Returns the graph's output `DataFrame`. Returns: (`towhee.dataframe.DataFrame`) The output `DataFrame`. """ return self.dataframes['_end_df'] def result(self) -> any: if self.outputs.size != 0: return self._out_iter else: # graph run failed, raise an exception for op in self._op_ctxs.values(): if op.status == OpStatus.FAILED: raise OpFailedError(op.err_msg) engine_log.warning('The pipeline runs successfully, but no data return') return None @property def op_ctxs(self): return self._op_ctxs @property def dataframes(self): return self._dataframes
[docs] def slow_down(self, df_name: str, time_sec: int): ''' Slow down the op whose df name it. ''' self._df_op[df_name].slow_down(time_sec)
[docs] def speed_up(self, df_name: str): ''' spped up the op whose df name it. ''' self._df_op[df_name].speed_up()
def stop(self): for op in self._op_ctxs: op.stop() def gc(self): for _, df in self._dataframes.items(): df.gc() def join(self): for op in self._op_ctxs.values(): op.join() def _build_components(self): """ Builds `DataFrame`s and `OperatorContext`s required to run this graph. """ # Build dataframes. dfs = {} for df_name, df_repr in self._repr.dataframes.items(): cols = [(col.name, col.vtype) for col in df_repr.columns] dfs[df_name] = DataFrame(df_name, cols) self._dataframes = dfs # Build operator contexts. self._op_ctxs = {} self._df_op = {} for _, op_repr in self._repr.operators.items(): op_ctx = OperatorContext(op_repr, self.dataframes) self._df_op[op_repr.outputs[0]['df']] = op_ctx self._op_ctxs[op_ctx.name] = op_ctx self._out_iter = MapIterator(self.outputs) def __del__(self): engine_log.info('Graph % s end', self._idx)