# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from enum import Enum, auto
from collections import defaultdict
from towhee.dag.operator_repr import OperatorRepr
from towhee.engine.operator_runner.runner_base import RunnerStatus
from towhee.dataframe import DataFrame
from towhee.engine.operator_io import create_reader, create_writer
from towhee.engine.operator_runner import create_runner
from towhee.engine.thread_pool_task_executor import ThreadPoolTaskExecutor
class OpStatus(Enum):
    """
    Lifecycle states tracked by an OperatorContext.

    Values are assigned by `auto()` in declaration order, i.e. 1 through 5.
    """

    # Set when an OperatorContext is constructed; start() requires this state.
    NOT_RUNNING = auto()
    # Set by start() before runners are created.
    RUNNING = auto()
    # Terminal: every runner reported RunnerStatus.FINISHED.
    FINISHED = auto()
    # Terminal: runner creation failed or some runner reported RunnerStatus.FAILED.
    FAILED = auto()
    # NOTE(review): not assigned by OperatorContext in this file.
    STOPPED = auto()
class OperatorContext:
    """
    The OperatorContext manages an operator's input data and output data at runtime,
    as well as the operators' dependency within a GraphContext.

    The abstraction of OperatorContext hides the complexity of Dataframe management,
    input iteration, and data dependency between Operators. It offers a Task-based
    scheduling context.

    Args:
        op_repr: (OperatorRepr)
            The operator representation.
        dataframes: (`dict` of `DataFrame`)
            All the `DataFrames` in `GraphContext`, keyed by the `df` names that
            `op_repr.inputs` / `op_repr.outputs` refer to.
    """

    def __init__(self, op_repr: OperatorRepr, dataframes: Dict[str, DataFrame]):
        self._repr = op_repr
        # One reader per distinct input dataframe; one writer handed to every runner.
        self._readers = OperatorContext._create_reader(op_repr, dataframes)
        self._writer = OperatorContext._create_writer(op_repr, dataframes)
        self._op_runners = []
        self._op_status = OpStatus.NOT_RUNNING
        self._err_msg = None

    @staticmethod
    def _create_reader(op_repr, dataframes):
        """
        Build one reader per distinct input dataframe.

        Entries of `op_repr.inputs` are grouped by their `df` key. Each reader gets
        the `{name: col}` mapping of the columns consumed from that dataframe. All
        readers share the iteration type and optional params of `op_repr.iter_info`.

        Raises:
            KeyError: if an input refers to a dataframe missing from `dataframes`,
                or `op_repr.iter_info` has no 'type'.
        """
        inputs_index = defaultdict(dict)
        for item in op_repr.inputs:
            inputs_index[item['df']][item['name']] = item['col']
        iter_type = op_repr.iter_info['type']
        iter_params = op_repr.iter_info.get('params')
        # Resolve every dataframe before creating any reader, so a missing name
        # fails fast without leaving some readers already created.
        frames = {df_name: dataframes[df_name] for df_name in inputs_index}
        return [
            create_reader(frames[df_name], iter_type, indexes, iter_params)
            for df_name, indexes in inputs_index.items()
        ]

    @staticmethod
    def _create_writer(op_repr, dataframes):
        """
        Build the writer for this operator's output dataframe(s).

        Normally one op has one output dataframe. When several `op_repr.outputs`
        entries point at the same dataframe, it is passed to the writer once.
        """
        # dict.fromkeys dedups with the same hash/eq rules as a set, but keeps
        # first-appearance order, so multi-output ops get a deterministic order.
        outputs = list(dict.fromkeys(dataframes[output['df']] for output in op_repr.outputs))
        iter_type = op_repr.iter_info['type']
        return create_writer(iter_type, outputs)

    @property
    def name(self):
        """The operator name taken from the operator representation."""
        return self._repr.name

    @property
    def err_msg(self):
        """Error text recorded when the context failed, otherwise None."""
        return self._err_msg

    @property
    def status(self):
        """
        Calc op-ctx status by checking all runners of this op-ctx.

        FINISHED and FAILED are terminal and returned as-is. Otherwise the status
        becomes FAILED if any runner reports RunnerStatus.FAILED (`err_msg` is set
        from the last such runner's `msg`). It becomes FINISHED once every runner
        reports RunnerStatus.FINISHED, and is left unchanged in all other cases.
        """
        if self._op_status in (OpStatus.FINISHED, OpStatus.FAILED):
            return self._op_status
        if not self._op_runners:
            return self._op_status
        finished_count = 0
        for runner in self._op_runners:
            if runner.status == RunnerStatus.FAILED:
                self._op_status = OpStatus.FAILED
                self._err_msg = runner.msg
            elif runner.status == RunnerStatus.FINISHED:
                finished_count += 1
        if finished_count == len(self._op_runners):
            self._op_status = OpStatus.FINISHED
        return self._op_status

    def start(self, executor: ThreadPoolTaskExecutor) -> None:
        """
        Create `op_repr.threads` runners and push each one to `executor`.

        If runner creation raises AttributeError, the context is marked FAILED,
        the error text is stored in `err_msg`, and nothing is pushed. Any other
        exception also marks the context FAILED before it propagates.

        Raises:
            RuntimeError: if the context has already been started.
        """
        if self._op_status != OpStatus.NOT_RUNNING:
            raise RuntimeError('OperatorContext can only be started once')
        self._op_status = OpStatus.RUNNING
        runners = []
        try:
            for i in range(self._repr.threads):
                runners.append(
                    create_runner(
                        self._repr.iter_info['type'],
                        self._repr.name,
                        i,
                        self._repr.name,
                        self._repr.tag,
                        self._repr.function,
                        self._repr.init_args,
                        self._readers,
                        self._writer,
                    )
                )
        except AttributeError as e:
            self._err_msg = str(e)
            self._op_status = OpStatus.FAILED
            return
        except Exception:
            # Otherwise the status would stay RUNNING forever with no runners.
            self._op_status = OpStatus.FAILED
            raise
        # Keep runners only once all were created. join() calls join() on every
        # entry of _op_runners, and a runner that was never pushed to the executor
        # presumably never finishes (TODO confirm against runner_base).
        self._op_runners = runners
        for runner in self._op_runners:
            executor.push_task(runner)

    def slow_down(self, time_sec: int):
        """Forward `slow_down(time_sec)` to every runner; no-op unless RUNNING."""
        if self.status == OpStatus.RUNNING:
            for runner in self._op_runners:
                runner.slow_down(time_sec)

    def speed_up(self):
        """Forward `speed_up()` to every runner; no-op unless RUNNING."""
        if self.status == OpStatus.RUNNING:
            for runner in self._op_runners:
                runner.speed_up()

    def stop(self):
        """
        Ask every runner to stop via `set_stop()` (does not wait for them).

        Raises:
            RuntimeError: if the context is not RUNNING.
        """
        if self.status != OpStatus.RUNNING:
            raise RuntimeError('Op ctx is already not running.')
        for runner in self._op_runners:
            runner.set_stop()

    def join(self):
        """Call `join()` on every runner, then close the writer."""
        # Wait all runners finished.
        for runner in self._op_runners:
            runner.join()
        self._writer.close()