# towhee/models/vit/vit_block.py

# Original pytorch implementation by:
# 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
#       - https://arxiv.org/abs/2010.11929
# 'How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers'
#       - https://arxiv.org/abs/2106.10270
#
# Original code by / Copyright 2020, Ross Wightman.
# Modifications & additions by / Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import nn
from towhee.models.layers.attention import MultiHeadAttention
from towhee.models.layers.droppath import DropPath
from towhee.models.layers.mlp import Mlp
from towhee.models.layers.layers_with_relprop import GELU, LayerNorm, Add, Clone


class Block(nn.Module):
    """
    The Transformer encoder block used in ViT: multi-head self-attention
    followed by an MLP, each wrapped in a pre-norm residual connection.

    Args:
        dim (int): number of input features (embedding dimension)
        num_heads (int): number of attention heads
        mlp_ratio (int): ratio of the MLP hidden dimension to `dim`
        qkv_bias (bool): whether to add a bias term to the qkv projection
        qk_scale (float): scaling factor for the query-key dot product
            (overrides the default if set)
        drop_ratio (float): dropout rate of the attention output projection
            and of the MLP
        attn_drop_ratio (float): dropout rate of the attention map
        drop_path_ratio (float): stochastic-depth (drop-path) rate of the
            residual branches
        act_layer: activation layer
        norm_layer: normalization layer
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None,
                 drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
                 act_layer=GELU, norm_layer=LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio
        )
        # Stochastic depth: randomly drop a residual branch during training.
        self.drop_path = DropPath(drop_prob=drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop_ratio)
        # Relprop-aware residual adds and input clones: they record the
        # forward pass so relevance can be split and merged in `relprop`.
        self.add1 = Add()
        self.add2 = Add()
        self.clone1 = Clone()
        self.clone2 = Clone()

    def forward(self, x):
        # Pre-norm residuals: x + Attn(LN(x)), then x + MLP(LN(x)).
        # The input is cloned before each residual split so the split is
        # recorded for relevance propagation; drop-path regularizes both
        # residual branches.
        x1, x2 = self.clone1(x, 2)
        x = self.add1([x1, self.drop_path(self.attn(self.norm1(x2)))])
        x1, x2 = self.clone2(x, 2)
        x = self.add2([x1, self.drop_path(self.mlp(self.norm2(x2)))])
        return x

    def relprop(self, cam, **kwargs):
        # Propagate relevance backwards through the block, mirroring the
        # forward pass in reverse: split the relevance at each residual add,
        # push one share through the sub-layer, then merge at the clone.
        (cam1, cam2) = self.add2.relprop(cam, **kwargs)
        cam2 = self.mlp.relprop(cam2, **kwargs)
        cam2 = self.norm2.relprop(cam2, **kwargs)
        cam = self.clone2.relprop((cam1, cam2), **kwargs)
        (cam1, cam2) = self.add1.relprop(cam, **kwargs)
        cam2 = self.attn.relprop(cam2, **kwargs)
        cam2 = self.norm1.relprop(cam2, **kwargs)
        cam = self.clone1.relprop((cam1, cam2), **kwargs)
        return cam
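

# A minimal sketch (not part of the towhee API; `build_encoder` and its
# defaults are illustrative assumptions) of how blocks like the one above
# are typically stacked into a ViT encoder, with the drop-path rate scaled
# linearly across the stack as in the timm convention.
def build_encoder(dim=768, depth=12, num_heads=12, mlp_ratio=4, drop_path_ratio=0.1):
    # Deeper blocks get a higher drop-path rate, from 0 up to drop_path_ratio.
    dpr = [drop_path_ratio * i / max(depth - 1, 1) for i in range(depth)]
    return nn.Sequential(*[
        Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_path_ratio=dpr[i])
        for i in range(depth)
    ])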


if __name__ == '__main__':
    import torch

    # Smoke test with ViT-Base token shapes: 196 patch tokens + 1 class token.
    x = torch.rand(1, 197, 768)
    model = Block(dim=768, num_heads=8)
    out = model(x)
    print(out.shape)  # torch.Size([1, 197, 768])
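
    # Relevance-propagation sketch. Assumptions: the relprop-enabled layers
    # follow the Chefer et al. "Transformer Explainability" API, rely on
    # autograd internally (so the input must require gradients), and accept
    # an LRP `alpha` keyword forwarded through **kwargs.
    x_lrp = torch.rand(1, 197, 768, requires_grad=True)
    out_lrp = model(x_lrp)
    cam = model.relprop(torch.ones_like(out_lrp), alpha=1)
    print(cam.shape)  # expected: torch.Size([1, 197, 768])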