Source code for towhee.models.layers.window_attention3d

# original code from https://github.com/SwinTransformer/Video-Swin-Transformer
# modified by Zilliz.

from torch import nn
import torch
from towhee.models.utils.weight_init import trunc_normal_


class WindowAttention3D(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    ref1: https://arxiv.org/pdf/2010.11929.pdf
    ref2: https://arxiv.org/pdf/2106.13230.pdf

    Args:
        dim (`int`):
            Number of input channels, see ref2.
        window_size (`tuple[int]`):
            The temporal length, height and width of the window, see ref2.
        num_heads (`int`):
            Number of attention heads of MSA, see ref1.
        qkv_bias (`bool`):
            If True, add a learnable bias to query, key and value. Default: False
        qk_scale (`float`):
            Override the default qk scale of head_dim ** -0.5 if set.
        attn_drop (`float`):
            Dropout ratio of attention weights. Default: 0.0
        proj_drop (`float`):
            Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias: (2*Wd-1) * (2*Wh-1) * (2*Ww-1), nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """
        Forward function of window attention 3d.

        Args:
            x (`tensor`):
                Input features with shape of (num_windows*b, n, c), where n = Wd * Wh * Ww
                (P * M * M in ref2) is the number of tokens in one window and
                num_windows = (t * h * w) / (Wd * Wh * Ww). b is the batch size and c is the
                channel dimension, equal to the `dim` argument of `__init__`.
            mask (`tensor`):
                Attention mask with shape of (num_windows, n, n) or None.
        """
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:n, :n].reshape(-1)].reshape(n, n, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N

        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
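

# The block below is not part of the original module; it is a minimal sketch of the
# relative-position-bias bookkeeping set up in __init__, assuming a tiny (2, 2, 2) window,
# dim=96 and num_heads=3 (all values chosen only for illustration).
def _check_relative_position_index():
    win = (2, 2, 2)
    attn3d = WindowAttention3D(dim=96, window_size=win, num_heads=3)
    n_tokens = win[0] * win[1] * win[2]                                   # 8 tokens per window
    table_size = (2 * win[0] - 1) * (2 * win[1] - 1) * (2 * win[2] - 1)   # 27 bias entries
    # one bias row per possible relative offset, one column per head
    assert attn3d.relative_position_bias_table.shape == (table_size, 3)
    # a (token, token) index matrix that points into the bias table
    assert attn3d.relative_position_index.shape == (n_tokens, n_tokens)
    assert int(attn3d.relative_position_index.max()) == table_size - 1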
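

# Usage sketch (not part of the original module), assuming hypothetical sizes: a batch of 2,
# 4 windows per sample, a (2, 7, 7) window, dim=96 and 3 heads. The zero mask only stands in
# for the additive shifted-window mask produced elsewhere in the model.
def _example_forward():
    batch, num_windows, dim, heads = 2, 4, 96, 3
    window_size = (2, 7, 7)
    n = window_size[0] * window_size[1] * window_size[2]   # 98 tokens per window
    attn = WindowAttention3D(dim=dim, window_size=window_size, num_heads=heads)
    x = torch.randn(batch * num_windows, n, dim)           # (num_windows*b, n, c)
    out = attn(x)                                          # non-shifted window: no mask
    assert out.shape == x.shape
    mask = torch.zeros(num_windows, n, n)                  # shifted window: (num_windows, n, n) mask
    out_shifted = attn(x, mask=mask)
    assert out_shifted.shape == x.shape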