Source code for towhee.models.layers.relative_self_attention

# Original pytorch implementation by:
# 'MaxViT: Multi-Axis Vision Transformer'
#       - https://arxiv.org/pdf/2204.01697.pdf
# Original code by / Copyright 2021, Christoph Reich.
# Modifications & additions by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from towhee.models.utils.weight_init import trunc_normal_
from towhee.models.utils.get_relative_position_index import get_relative_position_index
from torch import nn
import torch


class RelativeSelfAttention(nn.Module):
    """
    Relative Self-Attention similar to Swin V1.
    Implementation inspired by Timm's Swin V1 implementation.

    Args:
        in_channels (int): Number of input channels.
        num_heads (int, optional): Number of attention heads. Default: 32
        grid_window_size (Tuple[int, int], optional): Grid/Window size to be utilized. Default: (7, 7)
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(
            self,
            in_channels: int,
            num_heads: int = 32,
            grid_window_size: Tuple[int, int] = (7, 7),
            attn_drop: float = 0.,
            drop: float = 0.
    ) -> None:
        """
        Constructor method
        """
        # Call super constructor
        super().__init__()
        # Save parameters
        self.in_channels: int = in_channels
        self.num_heads: int = num_heads
        self.grid_window_size: Tuple[int, int] = grid_window_size
        self.scale: float = num_heads ** -0.5
        self.attn_area: int = grid_window_size[0] * grid_window_size[1]
        # Init layers
        self.qkv_mapping = nn.Linear(in_features=in_channels, out_features=3 * in_channels, bias=True)
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj = nn.Linear(in_features=in_channels, out_features=in_channels, bias=True)
        self.proj_drop = nn.Dropout(p=drop)
        self.softmax = nn.Softmax(dim=-1)
        # Define a parameter table of relative position bias, shape: [(2 * Wh - 1) * (2 * Ww - 1), num_heads]
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * grid_window_size[0] - 1) * (2 * grid_window_size[1] - 1), num_heads))
        # Get pair-wise relative position index for each token inside the window
        self.register_buffer("relative_position_index",
                             get_relative_position_index(grid_window_size[0], grid_window_size[1]))
        # Init relative positional bias
        trunc_normal_(self.relative_position_bias_table, std=.02)
    def _get_relative_positional_bias(
            self
    ) -> torch.Tensor:
        """
        Returns the relative positional bias.

        Returns:
            relative_position_bias (torch.Tensor): Relative positional bias.
        """
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(self.attn_area, self.attn_area, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)
    def forward(
            self,
            data: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            data (torch.Tensor): Input tensor of the shape [B_, N, C].

        Returns:
            output (torch.Tensor): Output tensor of the shape [B_, N, C].
        """
        # Get shape of input
        b_, n, _ = data.shape
        # Perform query key value mapping
        qkv = self.qkv_mapping(data).reshape(b_, n, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scale query
        q = q * self.scale
        # Compute attention maps
        attn = self.softmax(q @ k.transpose(-2, -1) + self._get_relative_positional_bias())
        # Map value with attention maps
        output = (attn @ v).transpose(1, 2).reshape(b_, n, -1)
        # Perform final projection and dropout
        output = self.proj(output)
        output = self.proj_drop(output)
        return output
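

# Minimal usage sketch (not part of the original module): it assumes the input
# has already been partitioned into windows/grids of `grid_window_size`, so the
# token dimension N equals grid_window_size[0] * grid_window_size[1].
if __name__ == "__main__":
    # 8 windows of 7 * 7 = 49 tokens, 128 channels split across 4 heads of 32 channels each
    attention = RelativeSelfAttention(in_channels=128, num_heads=4, grid_window_size=(7, 7))
    windows = torch.rand(8, 7 * 7, 128)
    output = attention(windows)
    print(output.shape)  # torch.Size([8, 49, 128])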