# Source code for towhee.models.perceiver.cross_attention

# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz


from torch import nn

from towhee.models.perceiver.multi_head_attention import MultiHeadAttention


class CrossAttention(nn.Module):
    """
    Cross attention for Perceiver https://arxiv.org/pdf/2103.03206.pdf.

    The query input and the key/value input are layer-normalized separately,
    then passed to a multi-head attention layer in which the queries attend
    over the keys/values.

    Args:
        num_q_channels (`int`):
            Number of q channels.
        num_kv_channels (`int`):
            Number of k or v channels. k has the same channels as v.
        num_heads (`int`):
            Number of parallel attention heads.
        dropout (`float`):
            Dropout probability, forwarded to `MultiHeadAttention`.
    """

    def __init__(self, num_q_channels: int, num_kv_channels: int, num_heads: int, dropout: float):
        super().__init__()
        # Two independent norms: queries and keys/values may have different
        # channel counts, and each LayerNorm is sized to its own input.
        self.q_norm = nn.LayerNorm(num_q_channels)
        self.kv_norm = nn.LayerNorm(num_kv_channels)
        self.attention = MultiHeadAttention(
            num_q_channels=num_q_channels,
            num_kv_channels=num_kv_channels,
            num_heads=num_heads,
            dropout=dropout,
        )

    def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None):
        """
        Normalize both inputs and attend from `x_q` over `x_kv`.

        Args:
            x_q: Query input. Its last dimension must equal `num_q_channels`
                (required by `q_norm`).
            x_kv: Key/value input. Its last dimension must equal
                `num_kv_channels` (required by `kv_norm`).
            pad_mask: Optional padding mask, passed unchanged to the attention layer.
            attn_mask: Optional attention mask, passed unchanged to the attention layer.

        Returns:
            Whatever `MultiHeadAttention` returns for the normalized inputs.
        """
        x_q = self.q_norm(x_q)
        x_kv = self.kv_norm(x_kv)
        return self.attention(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask)