# original code from https://github.com/SwinTransformer/Video-Swin-Transformer
# modified by Zilliz.
import torch
from torch import nn

from towhee.models.utils.weight_init import trunc_normal_
class WindowAttention3D(nn.Module):
    """
    Window-based multi-head self-attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    ref1: https://arxiv.org/pdf/2010.11929.pdf
    ref2: https://arxiv.org/pdf/2106.13230.pdf

    Args:
        dim (`int`):
            Number of input channels, see ref2.
        window_size (`tuple[int]`):
            The temporal length, height and width of the window, see ref2.
        num_heads (`int`):
            Number of attention heads of MSA, see ref1.
        qkv_bias (`bool`):
            If True, add a learnable bias to query, key, value. Default: False.
        qk_scale (`float`):
            Override the default qk scale of head_dim ** -0.5 if set.
        attn_drop (`float`):
            Dropout ratio of attention weights. Default: 0.0.
        proj_drop (`float`):
            Dropout ratio of output. Default: 0.0.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
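        # The 1 / sqrt(head_dim) factor is the standard scaled dot-product
        # attention temperature: it keeps the q @ k^T logits from growing with
        # the head dimension so the softmax does not saturate (see ref1).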
        # define a parameter table of relative position bias
        # of shape ((2*Wd-1) * (2*Wh-1) * (2*Ww-1), nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))
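        # Along each axis, the relative offset between two tokens in a window
        # of extent W ranges over 2*W - 1 values (-(W-1) .. W-1), hence the
        # table holds one learnable bias per head for every possible 3D offset.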
        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        # indexing="ij" matches the historical meshgrid default and silences
        # the deprecation warning on PyTorch >= 1.10
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
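        # Broadcasting the flattened coordinates against themselves yields, for
        # every ordered token pair (i, j), the per-axis offset coords[i] - coords[j].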
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
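        # The shift-scale-sum above folds the three shifted offsets into one
        # flat index, like mixed-radix digits: for a (2, 7, 7) window (an
        # illustrative size, not a requirement) the axis offsets (d, h, w) map
        # to d * 13 * 13 + h * 13 + w, a unique row of the bias table.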
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """
        Forward function of window attention 3d.

        Args:
            x (`tensor`):
                Input features with shape of (num_windows*b, n, c), see ref2.
                n = Wd * Wh * Ww is the number of tokens per window,
                num_windows = (t * h * w) / (p * m * m) is the number of windows
                the input was partitioned into, b is the batch size, and c is
                the channel dimension (the `dim` argument of `__init__`).
            mask (`tensor`):
                Mask with shape of (num_windows, n, n), or None.
        """
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: b, nH, n, head_dim

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:n, :n].reshape(-1)].reshape(n, n, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # b, nH, n, n
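        # unsqueeze(0) adds a batch axis so the same per-head bias broadcasts
        # across every window in the batch.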
        if mask is not None:
            # fold the window axis out of the batch so the per-window mask can
            # be added, then fold it back
            nw = mask.shape[0]
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
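
# Minimal usage sketch (illustrative only; the shapes below are assumptions,
# not part of the original module). A (2, 7, 7) window with dim=96 and 3 heads
# gives n = 2 * 7 * 7 = 98 tokens per window.
if __name__ == "__main__":
    attn_module = WindowAttention3D(dim=96, window_size=(2, 7, 7), num_heads=3)
    x = torch.randn(8, 2 * 7 * 7, 96)  # (num_windows * b, n, c)
    out = attn_module(x)               # no mask: plain (non-shifted) W-MSA
    assert out.shape == x.shape        # attention preserves the input shape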