Source code for towhee.models.swin_transformer.model

# Copyright 2021 Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.
import math
import torch
from torch import nn
from towhee.models.utils.weight_init import trunc_normal_
from towhee.models.layers.patch_embed2d import PatchEmbed2D
from towhee.models.layers.patch_merging import PatchMerging
from towhee.models.utils.init_vit_weights import init_vit_weights
from towhee.models.swin_transformer.basic_layer import BasicLayer


class SwinTransformer(nn.Module):
    r"""
    Swin Transformer
    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
        - https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default: 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        weight_init (str): Weight initialization scheme, one of '', 'jax', 'jax_nlhb', 'nlhb'. Default: ''
        is_v2 (bool): If True, build the layers as Swin Transformer V2 blocks. Default: False
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, weight_init='', is_v2=False,
                 pretrained_window_sizes=None):
        super().__init__()

        if pretrained_window_sizes is None:
            pretrained_window_sizes = [0, 0, 0, 0]
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed2D(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        self.patch_grid = self.patch_embed.grid_size

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        else:
            self.absolute_pos_embed = None

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        layers = []
        for i_layer in range(self.num_layers):
            layers += [BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(self.patch_grid[0] // (2 ** i_layer),
                                  self.patch_grid[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                is_v2=is_v2,
                pretrained_window_size=pretrained_window_sizes[i_layer])
            ]
        self.layers = nn.Sequential(*layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        if weight_init.startswith('jax'):
            for n, m in self.named_modules():
                init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            self.apply(init_vit_weights)
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        x = self.layers(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x
    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
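
# --- Illustrative sketch, not part of the original module. ---
# It shows one conventional way to consume `no_weight_decay` /
# `no_weight_decay_keywords` when building optimizer parameter groups for this
# model. The helper name `build_param_groups` and the default weight decay
# value are assumptions made for this example only.
def build_param_groups(model: SwinTransformer, weight_decay: float = 0.05):
    """Split parameters into decay / no-decay groups, as is common for ViT-style models."""
    skip = model.no_weight_decay()
    skip_keywords = model.no_weight_decay_keywords()
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (norm weights, biases), the absolute position embedding,
        # and relative position bias tables are excluded from weight decay.
        if name in skip or any(kw in name for kw in skip_keywords) or param.ndim == 1:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    # Usage sketch: torch.optim.AdamW(build_param_groups(model), lr=1e-3)
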
# if __name__ == '__main__':
#     from towhee.models.swin_transformer.configs import build_configs
#     from towhee.models.utils.pretrained_utils import load_pretrained_weights
#     img = torch.randn(1, 3, 224, 224)
#     name = 'swin_tiny_patch4_window7_224'
#
#     arch, model_cfg = build_configs(name)
#     model = SwinTransformer(**arch)
#
#     load_pretrained_weights(model, name, model_cfg)
#     out = model(img)
#     print(out.shape)
#     pass
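
# --- Minimal usage sketch, not part of the original module. ---
# Unlike the commented example above, it does not rely on towhee's config or
# pretrained-weight utilities: it builds a randomly initialised model straight
# from the constructor defaults and runs one forward pass as a smoke test.
# The Swin-T-like hyperparameters below are illustrative assumptions.
if __name__ == '__main__':
    model = SwinTransformer(
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
    )
    dummy = torch.randn(1, 3, 224, 224)  # one RGB image at the default resolution
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 1000])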