# Source code for towhee.models.layers.multi_scale_transformer_block

# Copyright 2021  Facebook. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.

import torch
from torch import nn
from typing import List, Tuple

from towhee.models.layers.multi_scale_attention import MultiScaleAttention
from towhee.models.layers.droppath import DropPath
from towhee.models.layers.mlp import Mlp
from towhee.models.layers.pool_attention import AttentionPool

class MultiScaleBlock(nn.Module):
    """
    Multiscale vision transformer block.

    The block chains two residual sub-blocks: a multiscale attention
    sub-block followed by an MLP sub-block.

    ::

        Input
          |-------------------+
          ↓                   |
         Norm                 |
          ↓                   |
        MultiScaleAttention  Pool
          ↓                   |
        DropPath              |
          ↓                   |
        Summation ←-----------+
          |
          |-------------------+
          ↓                   |
         Norm                 |
          ↓                   |
         Mlp                 Proj
          ↓                   |
        DropPath              |
          ↓                   |
        Summation ←-----------+

    ``Proj`` is a linear layer that only exists when ``dim != dim_out``;
    otherwise the second shortcut is the identity.

    Args:
        dim (int): Input feature dimension.
        dim_out (int): Output feature dimension.
        num_heads (int): Number of heads in the attention layer.
        mlp_ratio (float): Width of the MLP hidden layer as a multiple of
            ``dim``.
        qkv_bias (bool): If False, the qkv layer does not learn an additive
            bias.
        dropout_rate (float): Dropout rate. A rate of 0 disables dropout.
        droppath_rate (float): DropPath rate. A rate of 0 disables DropPath
            (an ``nn.Identity`` is used instead).
        activation (nn.Module): Activation layer class used in the MLP.
        norm_layer (nn.Module): Normalization layer class used for the two
            pre-norms of this block.
        kernel_q (_size_3_t): Pooling kernel size for q. As with
            ``kernel_kv``, a kernel of 1 in every dimension means pooling is
            not used.
        kernel_kv (_size_3_t): Pooling kernel size for k and v. If it is 1
            in every dimension, pooling is not used.
        stride_q (_size_3_t): Pooling stride for q. It is also the stride of
            the max-pool applied to the attention shortcut.
        stride_kv (_size_3_t): Pooling stride for k and v.
        pool_mode (nn.Module): Pooling mode passed to the attention layer.
        has_cls_embed (bool): If True, the first token of the input is a cls
            token and pooling is not applied to it. Otherwise the input
            contains no cls token.
        pool_first (bool): If True, pooling is applied before the qkv
            projection; otherwise it is applied after it.
    """

    def __init__(
        self,
        dim,
        dim_out,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        dropout_rate=0.0,
        droppath_rate=0.0,
        activation=nn.GELU,
        norm_layer=nn.LayerNorm,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        pool_mode=nn.Conv3d,
        has_cls_embed=True,
        pool_first=False,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_embed = has_cls_embed

        # Geometry of the max-pool on the attention shortcut. Its stride is
        # stride_q, and every dimension with a stride above 1 gets a kernel
        # one wider than that stride. Presumably this keeps the shortcut
        # token grid aligned with the pooled query grid — TODO confirm.
        skip_kernel = [st + 1 if st > 1 else st for st in stride_q]
        skip_padding = [int(k // 2) for k in skip_kernel]

        # Submodules are created in the same order as upstream. Keep this
        # order in case any layer draws from the RNG while initializing.
        self.norm1 = norm_layer(dim)
        self.attn = MultiScaleAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            # NOTE(review): always nn.LayerNorm, regardless of norm_layer.
            norm_layer=nn.LayerNorm,
            has_cls_embed=has_cls_embed,
            pool_mode=pool_mode,
            pool_first=pool_first,
        )
        if droppath_rate > 0.0:
            self.drop_path = DropPath(drop_prob=droppath_rate)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim_out,
            act_layer=activation,
            drop=dropout_rate,
        )
        if dim != dim_out:
            # Widens the MLP shortcut to dim_out so it can be summed with
            # the MLP output.
            self.proj = nn.Linear(dim, dim_out)
        # skip_kernel has one entry per stride_q dimension, so the max-pool
        # is always built for a non-empty stride. With an all-ones stride_q
        # it is a 1x1x1 pool with stride 1 and padding 0.
        if skip_kernel:
            self.pool_skip = nn.MaxPool3d(
                kernel_size=skip_kernel,
                stride=stride_q,
                padding=skip_padding,
                ceil_mode=False,
            )
        else:
            self.pool_skip = None

    def forward(
        self, x: torch.Tensor, thw_shape: List[int]
    ) -> Tuple[torch.Tensor, List[int]]:
        """
        Apply the attention sub-block and then the MLP sub-block.

        Args:
            x (torch.Tensor): Input tensor of flattened tokens. Presumably
                shaped (batch, tokens, dim), with the cls token first when
                has_cls_embed is True — TODO confirm.
            thw_shape (List[int]): The shape of the input tensor before
                flattening.

        Returns:
            Tuple[torch.Tensor, List[int]]: The block output, whose feature
            width is ``dim_out``, and the shape returned by the attention
            layer.
        """
        # Attention sub-block: pre-norm attention plus a pooled shortcut.
        attn_out, thw_shape_out = self.attn(self.norm1(x), thw_shape)
        shortcut_pool = AttentionPool(
            pool=self.pool_skip,
            thw_shape=thw_shape,
            has_cls_embed=self.has_cls_embed,
        )
        shortcut, _ = shortcut_pool(x)
        x = shortcut + self.drop_path(attn_out)

        # MLP sub-block. When the width changes, the shortcut is projected
        # from the normalized activations rather than from x itself.
        normed = self.norm2(x)
        mlp_out = self.mlp(normed)
        shortcut = x if self.dim == self.dim_out else self.proj(normed)
        return shortcut + self.drop_path(mlp_out), thw_shape_out