Source code for towhee.models.hornet.configs

# Configuration & model weights are from https://github.com/raoyongming/HorNet
#
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from towhee.models.hornet import GatedConv, Block, GlobalLocalFilter


[docs]def get_configs(model_name): configs = { 'in_chans': 3, 'num_classes': 1000, 'depths': (2, 3, 18, 2), 'base_dim': 128, 'drop_path_rate': 0., 'layer_scale_init_value': 1e-6, 'head_init_scale': 1., 'gnconv': [ partial(GatedConv, order=2, s=1.0 / 3.0), partial(GatedConv, order=3, s=1.0 / 3.0), partial(GatedConv, order=4, s=1.0 / 3.0), partial(GatedConv, order=5, s=1.0 / 3.0), ], 'block': Block, 'uniform_init': False } gf_gnconv = [ partial(GatedConv, order=2, s=1.0 / 3.0), partial(GatedConv, order=3, s=1.0 / 3.0), partial(GatedConv, order=4, s=1.0 / 3.0, h=14, w=8, gflayer=GlobalLocalFilter), partial(GatedConv, order=5, s=1.0 / 3.0, h=7, w=4, gflayer=GlobalLocalFilter), ] if model_name is None: pass elif model_name == 'hornet_tiny_7x7': configs.update( url='https://cloud.tsinghua.edu.cn/f/1ca970586c6043709a3f/?dl=1', base_dim=64 ) elif model_name == 'hornet_tiny_gf': configs.update( url='https://cloud.tsinghua.edu.cn/f/511faad0bde94dfcaa54/?dl=1', base_dim=64, gnconv=gf_gnconv ) elif model_name == 'hornet_small_7x7': configs.update( url='https://cloud.tsinghua.edu.cn/f/46422799db2941f7b684/?dl=1', base_dim=96 ) elif model_name == 'hornet_small_gf': configs.update( url='https://cloud.tsinghua.edu.cn/f/8405c984bf084d2ba85a/?dl=1', base_dim=96, gnconv=gf_gnconv ) elif model_name == 'hornet_base_7x7': configs.update( url='https://cloud.tsinghua.edu.cn/f/5c86cb3d655d4c17a959/?dl=1' ) elif model_name == 'hornet_base_gf': configs.update( url='https://cloud.tsinghua.edu.cn/f/6c84935e63b547f383fb/?dl=1', gnconv=gf_gnconv ) elif model_name == 'hornet_base_gf_img384': configs.update( gnconv=[ partial(GatedConv, order=2, s=1.0 / 3.0), partial(GatedConv, order=3, s=1.0 / 3.0), partial(GatedConv, order=4, s=1.0 / 3.0, h=24, w=13, gflayer=GlobalLocalFilter), partial(GatedConv, order=5, s=1.0 / 3.0, h=12, w=7, gflayer=GlobalLocalFilter), ] ) elif model_name == 'hornet_large_7x7_22k': configs.update( url='https://cloud.tsinghua.edu.cn/f/4de41e26cb254c28a61a/?dl=1', num_classes=21841, base_dim=192 ) elif model_name == 'hornet_large_gf_22k': configs.update( url='https://cloud.tsinghua.edu.cn/f/8679b6acf63c41e285d9/?dl=1', num_classes=21841, base_dim=192, gnconv=gf_gnconv ) elif model_name == 'hornet_large_gf_img384_22k': configs.update( url='https://cloud.tsinghua.edu.cn/f/f36957d46eef47da9c25/?dl=1', num_classes=21841, base_dim=192, gnconv=[ partial(GatedConv, order=2, s=1.0 / 3.0), partial(GatedConv, order=3, s=1.0 / 3.0), partial(GatedConv, order=4, s=1.0 / 3.0, h=24, w=13, gflayer=GlobalLocalFilter), partial(GatedConv, order=5, s=1.0 / 3.0, h=12, w=7, gflayer=GlobalLocalFilter), ] ) else: raise ValueError(f'Invalid model name: {model_name}.') return configs