# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import Union
from pathlib import Path
from importlib import import_module
from towhee.hub.repo_manager import RepoManager
from towhee.utils.log import engine_log
from towhee.utils.yaml_utils import load_yaml, dump_yaml
class OperatorManager(RepoManager):
    """
    The Repo Manager to manage the operator repos.

    Args:
        author (`str`):
            The author of the repo.
        repo (`str`):
            The name of the repo.
        root (`str`):
            The root url where the repo is located, defaults to 'https://hub.towhee.io'.
    """

    def __init__(self, author: str, repo: str, root: str = 'https://hub.towhee.io'):
        super().__init__(author, repo, root)
        # 2 represents operators when creating a repo in Towhee's hub
        self._class = 2

    @staticmethod
    def _repo_to_class_name(name: str) -> str:
        """Turn a dash-separated repo name into its class name, e.g. `image-decode` -> `ImageDecode`."""
        return ''.join(part.title() for part in name.split('-'))

    @staticmethod
    def _repo_to_module_name(name: str) -> str:
        """Turn a dash-separated repo name into its module name, e.g. `image-decode` -> `image_decode`."""
        return name.replace('-', '_')

    def create(self, password: str) -> None:
        """
        Create a repo under current account, unless it already exists.

        Args:
            password (`str`):
                Current author's password.

        Raises:
            (`HTTPError`)
                Raise error in request.
        """
        if self.exists():
            engine_log.info('%s/%s repo already exists.', self._author, self._repo)
        else:
            self.hub_utils.create(password, self._class)

    def create_with_token(self, token: str) -> None:
        """
        Create a repo under current account, unless it already exists.

        Args:
            token (`str`):
                Current author's token.

        Raises:
            (`HTTPError`)
                Raise error in request.
        """
        if self.exists():
            engine_log.info('%s/%s repo already exists.', self._author, self._repo)
        else:
            self.hub_utils.create_repo(token, self._class)

    def init_nnoperator(self, file_temp: Union[str, Path], file_dest: Union[str, Path], framework: str = 'pytorch') -> None:
        """
        Initialize the files under file_dest by moving and updating the text under file_temp.

        Template files ending in `.md`, `.yaml`, `template.py` or `__init__.py` are written to
        file_dest through `hub_utils.update_text`, with the template's repo path, repo name,
        class name and 'pytorch' replaced by this repo's values. Every other entry except `.git`
        is moved with `os.rename`. If `framework` is not 'pytorch', the `pytorch` directory in
        file_dest is renamed to `framework`.

        Args:
            file_temp (`Union[str, Path]`):
                The path to the template files.
            file_dest (`Union[str, Path]`):
                The path to the local repo to init.
            framework (`str`):
                The framework for the model, defaults to 'pytorch'.

        Raises:
            (`HTTPError`)
                Raise error in request.
            (`OSError`)
                Raise error in writing file.
        """
        # NOTE(review): self._temp is assumed to be provided by RepoManager (template repo names).
        repo_temp = self._temp['nnoperator']
        ori_str_list = [f'author/{repo_temp}', repo_temp, self._repo_to_class_name(repo_temp), 'pytorch']
        tar_str_list = [f'{self._author}/{self._repo}', self._repo, self._repo_to_class_name(self._repo), framework]
        temp_module = self._repo_to_module_name(repo_temp)
        repo_module = self._repo_to_module_name(self._repo)
        dest = Path(file_dest)
        for file in Path(file_temp).glob('*'):
            if file.name.endswith(('.md', '.yaml', 'template.py', '__init__.py')):
                new_file = dest / file.name.replace(temp_module, repo_module)
                self.hub_utils.update_text(ori_str_list, tar_str_list, str(file), str(new_file))
            elif file.name != '.git':
                os.rename(file, dest / file.name)
        if framework != 'pytorch':
            os.rename(dest / 'pytorch', dest / framework)

    def init_pyoperator(self, file_temp: Union[str, Path], file_dest: Union[str, Path]) -> None:
        """
        Initialize the files under file_dest by moving and updating the text under file_temp.

        Template files ending in `.md`, `.yaml`, `template.py` or `__init__.py` are written to
        file_dest through `hub_utils.update_text`, with the template's namespace path, module
        name, repo name and class name replaced by this repo's values. Every other entry except
        `.git` is moved with `os.rename`.

        Args:
            file_temp (`Union[str, Path]`):
                The path to the template files.
            file_dest (`Union[str, Path]`):
                The path to the local repo to init.

        Raises:
            (`HTTPError`)
                Raise error in request.
            (`OSError`)
                Raise error in writing file.
        """
        repo_temp = self._temp['pyoperator']
        temp_module = self._repo_to_module_name(repo_temp)
        repo_module = self._repo_to_module_name(self._repo)
        ori_str_list = [f'namespace.{temp_module}', temp_module, repo_temp, self._repo_to_class_name(repo_temp)]
        tar_str_list = [f'{self._author}.{repo_module}', repo_module, self._repo, self._repo_to_class_name(self._repo)]
        dest = Path(file_dest)
        for file in Path(file_temp).glob('*'):
            if file.name.endswith(('.md', '.yaml', 'template.py', '__init__.py')):
                new_file = dest / file.name.replace(temp_module, repo_module)
                self.hub_utils.update_text(ori_str_list, tar_str_list, str(file), str(new_file))
            elif file.name != '.git':
                os.rename(file, dest / file.name)

    def generate_yaml(self, local_dir: Union[str, Path, None] = None) -> None:
        """
        Generate the yaml of Operator.

        Imports the operator module `<repo_module>` from local_dir, reads the type annotations
        of the operator class's `__init__` and `__call__`, and writes them to
        `<repo_module>.yaml` in local_dir. If that yaml already exists, an error is logged and
        nothing is written.

        Args:
            local_dir (`Union[str, Path]`):
                The directory to the repo, defaults to the current working directory at call time.
        """
        # Resolve the default at call time; a `Path.cwd()` default would be frozen at import time.
        local_dir = Path.cwd() if local_dir is None else Path(local_dir)
        module_name = self._repo_to_module_name(self._repo)
        yaml_file = local_dir / (module_name + '.yaml')
        if yaml_file.exists():
            engine_log.error('There already exists %s.', yaml_file)
            return

        # Make the repo's module importable, without piling up duplicate sys.path entries.
        if str(local_dir) not in sys.path:
            sys.path.append(str(local_dir))
        class_name = self._repo_to_class_name(self._repo)
        author_operator = self._author + '/' + self._repo
        # import the class from repo
        cls = getattr(import_module(module_name), class_name)

        # Copy the annotation dicts: they are the live `__annotations__` of the operator's
        # methods, and removing 'return' in place would mutate the imported class.
        # getattr defaults cover inherited slot wrappers (e.g. object.__init__) without annotations.
        init_args = dict(getattr(cls.__init__, '__annotations__', {}))
        init_args.pop('return', None)
        call_func = dict(getattr(cls.__call__, '__annotations__', {}))
        call_return = call_func.pop('return', None)
        # The return annotation is expected to carry its own field annotations (e.g. a NamedTuple).
        call_output = getattr(call_return, '__annotations__', None)
        if call_output is None:
            engine_log.warning('Cannot derive output annotations of %s.__call__.', class_name)
            call_output = {}

        data = {
            'name': self._repo,
            'labels': {
                'recommended_framework': '', 'class': '', 'others': ''
            },
            'operator': author_operator,
            'init': self.hub_utils.convert_dict(init_args),
            'call': {
                'input': self.hub_utils.convert_dict(call_func), 'output': self.hub_utils.convert_dict(call_output)
            }
        }
        with open(yaml_file, 'a', encoding='utf-8') as outfile:
            dump_yaml(data, outfile)

    def check(self, local_dir: Union[str, Path, None] = None) -> bool:
        """
        Check if the main file exists and match the file name.

        Both `<repo_module>.py` and `<repo_module>.yaml` must exist in local_dir, and the yaml
        must pass `check_yaml`.

        Args:
            local_dir (`Union[str, Path]`):
                The directory to the repo, defaults to the current working directory at call time.

        Returns:
            (`bool`)
                Check if passed.
        """
        local_dir = Path.cwd() if local_dir is None else Path(local_dir)
        # check if the main file exists and match the file name
        file_name = self._repo_to_module_name(self._repo)
        for file in (f'{file_name}.py', f'{file_name}.yaml'):
            if not (local_dir / file).exists():
                return False
        return self.check_yaml(local_dir)

    def check_yaml(self, local_dir: Union[str, Path, None] = None) -> bool:
        """
        Check if the yaml file matches the format.

        The yaml passes when it has an `init` key and non-null `call.input` and `call.output`
        entries.

        Args:
            local_dir (`Union[str, Path]`):
                The directory to the repo, defaults to the current working directory at call time.

        Returns:
            (`bool`)
                Check if passed; False also when the yaml file is missing.
        """
        local_dir = Path.cwd() if local_dir is None else Path(local_dir)
        yaml_file = local_dir / (self._repo_to_module_name(self._repo) + '.yaml')
        try:
            with open(yaml_file, 'r', encoding='utf-8') as input_file:
                dicts = load_yaml(input_file)
        except FileNotFoundError:
            return False
        try:
            return 'init' in dicts and dicts['call']['input'] is not None and dicts['call']['output'] is not None
        except (KeyError, TypeError):
            # TypeError covers an empty/non-mapping document or a null/non-mapping `call` entry.
            return False