Source code for towhee.models.multiscale_vision_transformers.create_mvit

# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import nn

from towhee.models.multiscale_vision_transformers.mvit import MViT


def create_mvit_model(
        model_name: str = "imagenet_b_16_conv",
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        change_model_keys: bool = True,
) -> nn.Module:
    """
    Create a Multiscale Vision Transformers (MViT) model.
    https://arxiv.org/abs/2104.11227

    Args:
        model_name (`str`):
            Multiscale Vision Transformers model name.
        checkpoint_path (`str`):
            Local checkpoint path; defaults to None. Checkpoint weights can be
            downloaded from
            https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md.
        device (`str`):
            Model device, "cpu" or "cuda".
        change_model_keys (`bool`):
            This MViT structure differs slightly from the facebookresearch one
            in order to support visualization, so set `change_model_keys` to
            True when loading a pretrained checkpoint from facebookresearch.
    """

    def _change_model_keys(model_state):
        # Rename attention pooling/norm weights from the facebookresearch
        # layout to this MViT's layout, e.g.
        #   "blocks.0.attn.pool_q.weight" -> "blocks.0.attn.atn_pool_q.pool.weight"
        #   "blocks.0.attn.norm_k.bias"   -> "blocks.0.attn.atn_pool_k.norm.bias"
        for key in list(model_state.keys()):
            if "attn.pool" in key or "attn.norm" in key:
                key_word_list = key.split(".")
                kqv = key_word_list[3][-1]        # "q", "k", or "v"
                norm_pool = key_word_list[3][:4]  # "pool" or "norm"
                new_key = ".".join([
                    key_word_list[0], key_word_list[1], "attn",
                    "atn_pool_" + kqv, norm_pool, key_word_list[-1],
                ])
                model_state[new_key] = model_state.pop(key)
        return model_state

    if model_name == "imagenet_b_16_conv":
        config = {
            "patch_2d": True,
            "patch_stride": [4, 4],
            "embed_dim": 96,
            "num_heads": 1,
            "mlp_ratio": 4.0,
            "qkv_bias": True,
            "dropout_rate": 0.0,
            "depth": 16,
            "droppath_rate": 0.1,
            "mode": "conv",
            "cls_embed_on": True,
            "sep_pos_embed": False,
            "norm": "layernorm",
            "patch_kernel": [7, 7],
            "patch_padding": [3, 3],
            # Per-block pooling kernels/strides for Q, KV, and the skip path;
            # an empty list means no pooling in that block.
            "pool_q_kernel": [[], [1, 3, 3], [], [1, 3, 3], [], [], [], [],
                              [], [], [], [], [], [], [1, 3, 3], []],
            "pool_kv_kernel": [[1, 3, 3], [1, 3, 3], [1, 3, 3], [1, 3, 3],
                               [1, 3, 3], [1, 3, 3], [1, 3, 3], [1, 3, 3],
                               [1, 3, 3], [1, 3, 3], [1, 3, 3], [1, 3, 3],
                               [1, 3, 3], [1, 3, 3], [1, 3, 3], [1, 3, 3]],
            "pool_skip_kernel": [[], [1, 3, 3], [], [1, 3, 3], [], [], [], [],
                                 [], [], [], [], [], [], [1, 3, 3], []],
            "pool_q_stride": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [],
                              [], [], [], [], [], [], [1, 2, 2], []],
            "pool_kv_stride": [[1, 4, 4], [1, 2, 2], [1, 2, 2], [1, 1, 1],
                               [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1],
                               [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1],
                               [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
            "pool_skip_stride": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [],
                                 [], [], [], [], [], [], [1, 2, 2], []],
            # Double the embedding dim and head count at blocks 1, 3, and 14.
            "dim_mul_arg": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "head_mul_arg": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "norm_stem": False,
            "num_classes": 1000,
            "head_act": "softmax",
            "train_crop_size": 224,
            "test_crop_size": 224,
            "num_frames": 1,
            "input_channel_num": [3],
        }
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    model = MViT(**config)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_state = checkpoint["model_state"]
        if change_model_keys:
            model_state = _change_model_keys(model_state)
        model.load_state_dict(model_state)
    return model
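

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): build the ImageNet MViT-B
# variant and run a dummy forward pass. This assumes the 2D-patch ImageNet
# variant accepts standard (batch, channels, height, width) tensors at the
# 224x224 crop size set in the config above; adjust the input shape if your
# MViT build expects an explicit frame dimension.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = create_mvit_model(model_name="imagenet_b_16_conv", device="cpu")
    model.eval()
    dummy = torch.rand(1, 3, 224, 224)  # assumed input layout for patch_2d=True
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 1000]) for the 1000 ImageNet classes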