# Source code for towhee.models.perceiver.multi_head_attention

# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz


from torch import nn


class MultiHeadAttention(nn.Module):
    """
    Multi head attention for Perceiver https://arxiv.org/pdf/2103.03206.pdf.

    Thin wrapper around `torch.nn.MultiheadAttention` (constructed with
    `batch_first=True`) in which the queries and the keys/values may have
    different channel sizes. Keys and values are always the same tensor.

    Args:
        num_q_channels (`int`):
            Number of q channels. This is the attention embedding dim, so it is
            also the channel size of the output. Must be divisible by
            `num_heads` (required by `nn.MultiheadAttention`).
        num_kv_channels (`int`):
            Number of k or v channels. k has the same channels as v.
        num_heads (`int`):
            Number of parallel attention heads.
        dropout (`float`):
            Dropout probability, forwarded to `nn.MultiheadAttention`.
    """

    def __init__(self, num_q_channels: int, num_kv_channels: int, num_heads: int, dropout: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=num_q_channels,
            num_heads=num_heads,
            kdim=num_kv_channels,
            vdim=num_kv_channels,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None):
        """
        Forward function.

        Args:
            x_q (`Tensor`):
                Query embeddings, batch-first: (batch, q_len, num_q_channels).
            x_kv (`Tensor`):
                Key embeddings, batch-first: (batch, kv_len, num_kv_channels).
                Also used as the values (key equals value).
            pad_mask (`Tensor`, optional):
                Key padding mask passed as `key_padding_mask`, shape
                (batch, kv_len). For a bool mask, True marks key positions
                to ignore.
            attn_mask (`Tensor`, optional):
                Attention mask passed as `attn_mask`, shape (q_len, kv_len)
                or (batch * num_heads, q_len, kv_len).

        Returns:
            `Tensor`: Attention output of shape (batch, q_len, num_q_channels).
        """
        # Only the attention output is used. need_weights=False skips
        # computing the (head-averaged) attention weights and lets PyTorch
        # use its fused attention kernel where available.
        output, _ = self.attention(
            x_q,
            x_kv,
            x_kv,
            key_padding_mask=pad_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        return output