# Source code for towhee.models.perceiver.self_attention

# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz


from torch import nn

from towhee.models.perceiver.multi_head_attention import MultiHeadAttention


class SelfAttention(nn.Module):
    """
    Self attention for Perceiver https://arxiv.org/pdf/2103.03206.pdf.

    Args:
        num_channels (`int`):
            Number of channels.
        num_heads (`int`):
            Number of parallel attention heads.
        dropout (`float`):
            Dropout probability.
    """

    def __init__(self, num_channels: int, num_heads: int, dropout: float):
        super().__init__()
        # Pre-norm layout: the input is normalized before attention is applied.
        self.norm = nn.LayerNorm(num_channels)
        # Self-attention: queries and keys/values come from the same input,
        # so both channel counts equal num_channels.
        self.attention = MultiHeadAttention(
            num_q_channels=num_channels,
            num_kv_channels=num_channels,
            num_heads=num_heads,
            dropout=dropout,
        )

    def forward(self, x, pad_mask=None, attn_mask=None):
        """
        Forward function.

        Args:
            x (`Tensor`):
                Input tensor.
            pad_mask (`Tensor`):
                Padding mask.
            attn_mask (`Tensor`):
                Attention mask.

        Returns:
            (`Tensor`)
                Attention output.
        """
        x = self.norm(x)
        # The normalized input is used as both the query and the key/value.
        return self.attention(x, x, pad_mask=pad_mask, attn_mask=attn_mask)