# Source code for towhee.serve.triton.to_triton_models

# Copyright 2021 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from pathlib import Path
import inspect
import pickle
import json
from abc import ABC
import logging

from towhee.serve.triton.triton_config_builder import TritonModelConfigBuilder, create_modelconfig, EnsembleConfigBuilder
from towhee.serve.triton.python_model_builder import gen_model_from_op, gen_model_from_pickled_callable
from towhee.serve.triton import constant

# Module-level logger; records still propagate to the root logger's handlers,
# but can now be filtered/levelled per module.
logger = logging.getLogger(__name__)

class TritonFiles:
    '''
    File path name.

    Path helper for one entry of a Triton model repository rooted at
    ``<root>/<model_name>``: the ``config.pbtxt`` file plus the version
    directory ``1`` that holds the model artifacts.
    '''

    def __init__(self, root: str, model_name: str):
        self._root = Path(root) / model_name

    @property
    def root(self) -> Path:
        return self._root

    @property
    def config_file(self) -> Path:
        return self._root / 'config.pbtxt'

    @property
    def model_path(self) -> Path:
        # Version directory; this module always emits version '1'.
        return self._root / '1'

    @property
    def dc_config_file(self) -> Path:
        return self.model_path / constant.DC_CONFIG_FILE

    @property
    def python_model_file(self) -> Path:
        # Triton's python backend loads ``model.py`` from the version
        # directory; joining with '' would return the directory itself.
        return self.model_path / 'model.py'

    @property
    def trt_model_file(self) -> Path:
        return self.model_path / 'model.plan'

    @property
    def onnx_model_file(self) -> Path:
        return self.model_path / 'model.onnx'

    @property
    def preprocess_pickle(self) -> str:
        # Bare file name (not a full path); see preprocess_pickle_path.
        return 'preprocess.pickle'

    @property
    def postprocess_pickle(self) -> str:
        # Bare file name (not a full path); see postprocess_pickle_path.
        return 'postprocess.pickle'

    @property
    def postprocess_pickle_path(self) -> Path:
        return self.model_path / self.postprocess_pickle

    @property
    def preprocess_pickle_path(self) -> Path:
        return self.model_path / self.preprocess_pickle
class ToTriton(ABC):
    '''
    ToTriton Base.

    Builds a Triton model directory for ``obj``. ``to_triton`` runs
    ``_create_model_dir``, ``_prepare_model`` and ``_prepare_config`` in order
    and stops at the first step that returns a falsy value. Subclasses override
    ``_prepare_model`` (and optionally ``_prepare_config``).
    '''

    def __init__(self, obj: 'Operator', model_root: str, model_name: str, op_config: Dict):
        self._obj = obj
        self._model_name = model_name
        self._op_config = op_config
        self._triton_files = TritonFiles(model_root, self._model_name)
        self._inputs = TritonModelConfigBuilder.get_input_schema(self._obj.input_schema())
        self._outputs = TritonModelConfigBuilder.get_output_schema(self._obj.output_schema())
        # Default backend; subclasses may change it in _prepare_model, which
        # runs before _prepare_config writes it into the config.
        self._backend = 'python'

    @property
    def inputs(self) -> Dict:
        return self._inputs

    @property
    def outputs(self) -> Dict:
        return self._outputs

    def _create_model_dir(self) -> bool:
        self._triton_files.root.mkdir(parents=True, exist_ok=True)
        self._triton_files.model_path.mkdir(parents=True, exist_ok=True)
        return True

    def _prepare_model(self) -> bool:
        return True

    def _prepare_config(self) -> bool:
        '''
        Write ``config.pbtxt`` without batching: max_batch_size 0 and dynamic
        batching disabled. ``device_ids`` and ``instance_count`` (default 1)
        come from the op config.
        '''
        device_ids = self._op_config.get('device_ids')
        instance_count = self._op_config.get('instance_count', 1)
        config_str = create_modelconfig(
            self._model_name, 0, self._inputs, self._outputs, self._backend,
            False, None, None, instance_count, device_ids
        )
        with open(self._triton_files.config_file, 'wt', encoding='utf-8') as f:
            f.write(config_str)
        return True

    def to_triton(self) -> bool:
        if self._create_model_dir() and self._prepare_model() and self._prepare_config():
            return True
        return False
class PyOpToTriton(ToTriton):
    '''
    PyOp to triton model.

    Writes the op config as JSON into the model version directory, then calls
    ``gen_model_from_op`` with the hub/name/init args and the operator's
    input/output schemas to produce the python model file.
    '''

    def __init__(self, op, model_root, model_name, op_hub, op_name, init_args, op_config):
        super().__init__(op, model_root, model_name, op_config)
        self._op_hub, self._op_name, self._init_args = op_hub, op_name, init_args

    def _prepare_model(self):
        files = self._triton_files
        # The op config is persisted first, before model generation.
        with open(files.dc_config_file, 'wt', encoding='utf-8') as cfg_fp:
            json.dump(self._op_config, cfg_fp)
        in_schema = self._obj.input_schema()
        out_schema = self._obj.output_schema()
        gen_model_from_op(
            files.python_model_file,
            self._op_hub,
            self._op_name,
            self._init_args,
            in_schema,
            out_schema,
        )
        return True
class PreprocessToTriton(ToTriton):
    '''
    Preprocess to triton model.

    Pickles ``op.preprocess`` into the model version directory and calls
    ``gen_model_from_pickled_callable`` with the op's package name, the
    package ``__init__.py`` and the pickle file name.
    '''

    def __init__(self, op, model_root, model_name, op_config):
        super().__init__(op.preprocess, model_root, model_name, op_config)
        op_module_info = inspect.getmodule(op)
        # ``__init__.py`` of the package containing the op's module; joining
        # with '' would yield the package directory rather than a file.
        self._init_file = Path(op_module_info.__file__).parent / '__init__.py'
        # Dotted name of that package (the module name minus its last part).
        self._module_name = '.'.join(op_module_info.__name__.split('.')[:-1])

    def _prepare_model(self):
        gen_model_from_pickled_callable(str(self._triton_files.python_model_file),
                                        self._module_name,
                                        str(self._init_file),
                                        str(self._triton_files.preprocess_pickle),
                                        self._obj.input_schema(),
                                        self._obj.output_schema()
                                        )
        # create pickle file next to the generated python model
        with open(self._triton_files.preprocess_pickle_path, 'wb') as f:
            pickle.dump(self._obj, f)
        return True
class PostprocessToTriton(ToTriton):
    '''
    Postprocess to triton model.

    Pickles ``op.postprocess`` into the model version directory and calls
    ``gen_model_from_pickled_callable`` with the op's package name, the
    package ``__init__.py`` and the pickle file name.
    '''

    def __init__(self, op, model_root, model_name, op_config):
        super().__init__(op.postprocess, model_root, model_name, op_config)
        op_module_info = inspect.getmodule(op)
        # ``__init__.py`` of the package containing the op's module; joining
        # with '' would yield the package directory rather than a file.
        self._init_file = Path(op_module_info.__file__).parent / '__init__.py'
        # Dotted name of that package (the module name minus its last part).
        self._module_name = '.'.join(op_module_info.__name__.split('.')[:-1])

    def _prepare_model(self):
        gen_model_from_pickled_callable(str(self._triton_files.python_model_file),
                                        self._module_name,
                                        str(self._init_file),
                                        str(self._triton_files.postprocess_pickle),
                                        self._obj.input_schema(),
                                        self._obj.output_schema()
                                        )
        # create pickle file next to the generated python model
        with open(self._triton_files.postprocess_pickle_path, 'wb') as f:
            pickle.dump(self._obj, f)
        return True
class ModelToTriton(ToTriton):
    '''
    NNOp to triton model.

    Saves ``op.model`` in the first format listed under
    ``op_config[constant.FORMAT_PRIORITY]`` that the model supports and saves
    successfully. The formats handled are 'onnx' (onnxruntime backend) and
    'tensorrt'.
    '''

    def __init__(self, op, model_root, model_name, op_config):
        super().__init__(op.model, model_root, model_name, op_config)
        self._model_format_priority = op_config.get(constant.FORMAT_PRIORITY, [])

    def _prepare_config(self) -> bool:
        '''
        All model open dynamic_batching.

        After towhee.dc supports config the batch info, we use the config.
        Example of input and output:
            {
                'INPUT0': ('TYPE_INT8', [-1, -1, 3]),
                'INPUT1': ('TYPE_FP32', [-1, -1, 3])
            }
        When ``op_config['dynamic_batching']`` is non-empty, the first dim of
        every input/output shape is treated as the batch dim and stripped.
        '''
        dynamic_batching = self._op_config.get('dynamic_batching', {})
        device_ids = self._op_config.get('device_ids')
        instance_count = self._op_config.get('instance_count', 1)
        if dynamic_batching:
            # Remove the batch dim: with max_batch_size > 0 Triton prepends the
            # batch dimension itself, so configured dims must exclude it.
            inputs = {k: (v[0], v[1][1:]) for k, v in self._inputs.items()}
            outputs = {k: (v[0], v[1][1:]) for k, v in self._outputs.items()}
            max_batch_size = dynamic_batching.get('max_batch_size', 1)
            enable_dynamic_batching = True
            preferred_batch_size = dynamic_batching.get('preferred_batch_size', None)
            preferred_max_queue_delay_microseconds = dynamic_batching.get('preferred_max_queue_delay_microseconds', None)
            config_str = create_modelconfig(
                self._model_name, max_batch_size, inputs, outputs, self._backend,
                enable_dynamic_batching, preferred_batch_size, preferred_max_queue_delay_microseconds,
                instance_count, device_ids
            )
        else:
            config_str = create_modelconfig(
                self._model_name, 0, self._inputs, self._outputs, self._backend,
                False, None, None, instance_count, device_ids
            )
        with open(self._triton_files.config_file, 'wt', encoding='utf-8') as f:
            f.write(config_str)
        return True

    def _prepare_model(self):
        '''
        Save the model in the highest-priority supported format.

        Formats are tried in priority order. The first one that saves
        successfully sets ``self._backend`` and ends the search.

        Returns:
            True if some format was saved, False otherwise.
        '''
        for fmt in self._model_format_priority:
            if fmt not in self._obj.supported_formats:
                continue
            if fmt == 'onnx':
                model_file, backend = self._triton_files.onnx_model_file, 'onnxruntime'
            elif fmt == 'tensorrt':
                model_file, backend = self._triton_files.trt_model_file, 'tensorrt'
            else:
                logger.error('Unkown optimize %s', fmt)
                continue
            if self._obj.save_model(fmt, model_file):
                # Stop at the first success. The backend that _prepare_config
                # writes must match the model file actually produced, and a
                # lower-priority format must not override it.
                self._backend = backend
                return True
        return False
class EnsembleToTriton:
    '''
    Create triton ensemble config

    Creates the model directory and its version directory, then writes the
    ``config.pbtxt`` produced by ``EnsembleConfigBuilder`` for ``dag``.
    '''

    def __init__(self, dag, model_root, model_name, batch_size):
        self._dag = dag
        self._model_name = model_name
        self._batch_size = batch_size
        self._triton_files = TritonFiles(model_root, self._model_name)

    def to_triton(self):
        files = self._triton_files
        for directory in (files.root, files.model_path):
            directory.mkdir(parents=True, exist_ok=True)
        builder = EnsembleConfigBuilder(self._dag, self._model_name, self._batch_size)
        config_text = builder.gen_config()
        with open(files.config_file, 'wt', encoding='utf-8') as fp:
            fp.write(config_text)
        return True