# Original PyTorch implementation from:
# 'MaxViT: Multi-Axis Vision Transformer'
# - https://arxiv.org/pdf/2204.01697.pdf
# Original code by / Copyright 2021, Christoph Reich.
# Modifications & additions by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from torch import nn

from towhee.models.utils.weight_init import trunc_normal_
from towhee.models.utils.get_relative_position_index import get_relative_position_index


class RelativeSelfAttention(nn.Module):
""" Relative Self-Attention similar to Swin V1. Implementation inspired by Timms Swin V1 implementation.
Args:
in_channels (int): Number of input channels.
num_heads (int, optional): Number of attention heads. Default 32
grid_window_size (Tuple[int, int], optional): Grid/Window size to be utilized. Default (7, 7)
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
drop (float, optional): Dropout ratio of output. Default: 0.0
"""
[docs] def __init__(
self,
in_channels: int,
num_heads: int = 32,
grid_window_size: Tuple[int, int] = (7, 7),
attn_drop: float = 0.,
drop: float = 0.
) -> None:
""" Constructor method """
# Call super constructor
super().__init__()
# Save parameters
self.in_channels: int = in_channels
self.num_heads: int = num_heads
self.grid_window_size: Tuple[int, int] = grid_window_size
self.scale: float = num_heads ** -0.5
self.attn_area: int = grid_window_size[0] * grid_window_size[1]
# Init layers
self.qkv_mapping = nn.Linear(in_features=in_channels, out_features=3 * in_channels, bias=True)
self.attn_drop = nn.Dropout(p=attn_drop)
self.proj = nn.Linear(in_features=in_channels, out_features=in_channels, bias=True)
self.proj_drop = nn.Dropout(p=drop)
self.softmax = nn.Softmax(dim=-1)
# Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * grid_window_size[0] - 1) * (2 * grid_window_size[1] - 1), num_heads))
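        # Relative offsets along each axis lie in [-(Wh - 1), Wh - 1], so (2*Wh-1) * (2*Ww-1) entries cover all token pairs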
# Get pair-wise relative position index for each token inside the window
self.register_buffer("relative_position_index", get_relative_position_index(grid_window_size[0],
grid_window_size[1]))
# Init relative positional bias
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def _get_relative_positional_bias(
self
) -> torch.Tensor:
""" Returns the relative positional bias.
Returns:
relative_position_bias (torch.Tensor): Relative positional bias.
"""
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(self.attn_area, self.attn_area, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)

    def forward(
self,
data: torch.Tensor
) -> torch.Tensor:
""" Forward pass.
Args:
data (torch.Tensor): Input tensor of the shape [B_, N, C].
Returns:
output (torch.Tensor): Output tensor of the shape [B_, N, C].
"""
# Get shape of input
b_, n, _ = data.shape
# Perform query key value mapping
qkv = self.qkv_mapping(data).reshape(b_, n, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
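        # q, k and v each have the shape [B_, num_heads, N, C // num_heads]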
# Scale query
q = q * self.scale
# Compute attention maps
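        # (the relative positional bias of shape [1, num_heads, N, N] broadcasts over the batch dimension)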
attn = self.softmax(q @ k.transpose(-2, -1) + self._get_relative_positional_bias())
# Map value with attention maps
output = (attn @ v).transpose(1, 2).reshape(b_, n, -1)
# Perform final projection and dropout
output = self.proj(output)
output = self.proj_drop(output)
return output
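

# Minimal usage sketch: the module expects window- or grid-partitioned tokens of shape [B_, N, C],
# where N = grid_window_size[0] * grid_window_size[1] and C = in_channels. The concrete values below
# (8 windows, 64 channels, 4 heads, the default 7x7 window) are illustrative assumptions only.
if __name__ == "__main__":
    attention = RelativeSelfAttention(in_channels=64, num_heads=4, grid_window_size=(7, 7))
    tokens = torch.rand(8, 7 * 7, 64)
    out = attention(tokens)
    print(out.shape)  # Expected: torch.Size([8, 49, 64])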