# Source code for towhee.models.layers.swin_transformer_block3d

# original code from https://github.com/SwinTransformer/Video-Swin-Transformer
# modified by Zilliz.

import torch
from torch import nn
from torch.utils import checkpoint
import torch.nn.functional as F

from towhee.models.layers.window_attention3d import WindowAttention3D
from towhee.models.utils.window_partition3d import window_partition
from towhee.models.utils.window_reverse3d import window_reverse
from towhee.models.layers.mlp import Mlp
from towhee.models.utils.get_window_size import get_window_size
from towhee.models.layers.droppath import DropPath
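

# The attention mask consumed by SwinTransformerBlock3D.forward() below is
# built by the caller; in the reference Video Swin repository (linked above)
# this is done by a `compute_mask` helper at the layer level. The sketch below
# is adapted from that reference implementation and is not part of this towhee
# module; it is included only to show how `mask_matrix` is typically produced.
def compute_mask(d, h, w, window_size, shift_size, device):
    """Sketch: build the (nW, N, N) attention mask for shifted windows."""
    img_mask = torch.zeros((1, d, h, w, 1), device=device)
    cnt = 0
    # label each region of the cyclically shifted volume with a distinct id
    for d_s in (slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None)):
        for h_s in (slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None)):
            for w_s in (slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None)):
                img_mask[:, d_s, h_s, w_s, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)  # nW, Wd*Wh*Ww
    # tokens from different regions must not attend to each other
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask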


class SwinTransformerBlock3D(nn.Module):
    """
    3D Swin Transformer Block.

    Args:
        dim (`int`):
            Number of input channels.
        num_heads (`int`):
            Number of attention heads.
        window_size (`tuple[int]`):
            Window size.
        shift_size (`tuple[int]`):
            Shift size for SW-MSA.
        mlp_ratio (`float`):
            Ratio of mlp hidden dim to embedding dim.
        qkv_bias (`bool`):
            If True, add a learnable bias to query, key, value. Default: True
        qk_scale (`float`):
            Override default qk scale of head_dim ** -0.5 if set.
        drop (`float`):
            Dropout rate. Default: 0.0
        attn_drop (`float`):
            Attention dropout rate. Default: 0.0
        drop_path (`float`):
            Stochastic depth rate. Default: 0.0
        act_layer (`nn.Module`):
            Activation layer. Default: nn.GELU
        norm_layer (`nn.Module`):
            Normalization layer. Default: nn.LayerNorm
        use_checkpoint (`bool`):
            If True, use gradient checkpointing to save memory. Default: False
    """
    def __init__(self, dim, num_heads, window_size=(2, 7, 7), shift_size=(0, 0, 0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint

        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
    def forward_part1(self, x, mask_matrix):
        b, d, h, w, c = x.shape
        window_size, shift_size = \
            get_window_size((d, h, w), self.window_size, self.shift_size)  # pylint: disable=unbalanced-tuple-unpacking

        x = self.norm1(x)
        # pad feature maps to multiples of window size
        pad_l = pad_t = pad_d0 = 0
        pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
        pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
        pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
        _, dp, hp, wp, _ = x.shape
        # cyclic shift
        if any(i > 0 for i in shift_size):
            shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # b*nW, Wd*Wh*Ww, c
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # b*nW, Wd*Wh*Ww, c
        # merge windows
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, b, dp, hp, wp)  # b d' h' w' c
        # reverse cyclic shift
        if any(i > 0 for i in shift_size):
            x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
        else:
            x = shifted_x

        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
            x = x[:, :d, :h, :w, :].contiguous()
        return x

    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))
    def forward(self, x, mask_matrix):
        """
        Forward function.

        Args:
            x (`torch.Tensor`):
                Input feature, tensor size (B, D, H, W, C).
            mask_matrix (`torch.Tensor`):
                Attention mask for cyclic shift.
        """
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)

        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)

        return x
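

# Usage sketch (not part of the original source). Shapes follow the docstring
# above: input is (B, D, H, W, C). The shifted case reuses the `compute_mask`
# sketch defined earlier in this file; in towhee the mask would normally be
# supplied by the enclosing layer.
if __name__ == "__main__":
    x = torch.randn(1, 8, 56, 56, 96)  # B=1, D=8 frames of 56x56 maps, C=96

    # plain window attention: no shift, so the mask is never read
    block = SwinTransformerBlock3D(dim=96, num_heads=3, window_size=(2, 7, 7),
                                   shift_size=(0, 0, 0))
    print(block(x, mask_matrix=None).shape)  # torch.Size([1, 8, 56, 56, 96])

    # shifted window attention: mask built at the (padded) input resolution
    shifted = SwinTransformerBlock3D(dim=96, num_heads=3, window_size=(2, 7, 7),
                                     shift_size=(1, 3, 3))
    mask = compute_mask(8, 56, 56, (2, 7, 7), (1, 3, 3), x.device)
    print(shifted(x, mask_matrix=mask).shape)  # torch.Size([1, 8, 56, 56, 96])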