Source code for towhee.serve.triton.bls.model_runner.generator_model

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import queue
from pathlib import Path
import threading

from towhee import ops
from towhee.serve.triton.bls.model_runner.transformer import RequestToOpInputs, OpOutputToResponses
from towhee.serve.triton import constant
from towhee.serve.triton.bls.utils.op_config import OpConfig
from towhee.serve.triton.bls.python_backend_wrapper import pb_utils

logger = logging.getLogger()


[docs]class TritonPythonModel: ''' Supports running towhee generator operators. ''' def initialize(self, args): device = None if args['model_instance_kind'] == 'GPU': device = int(args['model_instance_device_id']) self._op_config = OpConfig.load_from_file(self.op_config_file) if self._op_config is None: err = 'Load operator config file [%s] failed' % self.op_config_file raise IOError(err) hub = getattr(ops, self._op_config.op_hub) op_wrapper = getattr(hub, self._op_config.op_name)(**self._op_config.init_args) self._op = op_wrapper.get_op() if hasattr(self._op, 'to_device') and device is not None: self._op._device = device # pylint: disable=protected-access self._op.to_device() self._queue = queue.Queue() self._need_stop = False self._processor = threading.Thread(target=self.worker) self._processor.start() def worker(self): while not self._need_stop: request = self._queue.get() if isinstance(request, queue.Empty): break response_sender = request.get_response_sender() inputs = RequestToOpInputs(request, self._op.input_schema()).get_towhee_data() generator = self._op(*inputs) for ret in generator: if not isinstance(ret, tuple): ret = (ret, ) response = OpOutputToResponses(ret).to_triton_responses() response_sender.send(response) response_sender.send( flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) @property def op_config_file(self) -> str: if hasattr(self, '_op_config_file'): return self._op_config_file return str(Path(__file__).parent.resolve() / constant.OP_CONFIG_FILE) def execute(self, requests): for request in requests: self._queue.put(request) def finalize(self): self._need_stop = True self._queue.put(queue.Empty()) self._processor.join()