Source code for towhee.models.perceiver.create_cross_attention

# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz


import os

try:
    # pylint: disable=unused-import
    import fairscale
except ImportError:
    # fairscale is an optional dependency; install it on the fly if missing.
    os.system("pip install fairscale")
from fairscale.nn import checkpoint_wrapper

from towhee.models.perceiver.residual import Residual
from towhee.models.perceiver.sequential import Sequential
from towhee.models.perceiver.mlp import mlp
from towhee.models.perceiver.cross_attention import CrossAttention


def cross_attention_layer(
        num_q_channels: int,
        num_kv_channels: int,
        num_heads: int,
        dropout: float,
        activation_checkpoint: bool = False
):
    """
    Cross attention block for Perceiver https://arxiv.org/pdf/2103.03206.pdf.

    Args:
        num_q_channels (`int`):
            Number of q channels.
        num_kv_channels (`int`):
            Number of k or v channels. k has the same channels as v.
        num_heads (`int`):
            Number of parallel attention heads.
        dropout (`float`):
            Dropout probability.
        activation_checkpoint (`bool`):
            Use activation checkpointing.

    Return (`nn.Module`):
        Configured cross attention layer.
    """
    layer = Sequential(
        Residual(CrossAttention(num_q_channels, num_kv_channels, num_heads, dropout), dropout),
        Residual(mlp(num_q_channels), dropout),
    )
    return layer if not activation_checkpoint else checkpoint_wrapper(layer)
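

# A minimal usage sketch, assuming torch is installed and that the returned
# layer follows the upstream perceiver-io calling convention ``layer(x_q, x_kv)``;
# the channel counts and tensor shapes below are illustrative, not prescribed
# by the module.
if __name__ == "__main__":
    import torch

    layer = cross_attention_layer(
        num_q_channels=256,  # channels of the latent/query array
        num_kv_channels=64,  # channels of the input key/value array
        num_heads=4,         # 256 query channels / 4 heads = 64 per head
        dropout=0.1,         # dropout probability
    )

    x_q = torch.randn(2, 32, 256)   # (batch, num_latents, num_q_channels)
    x_kv = torch.randn(2, 128, 64)  # (batch, input_len, num_kv_channels)
    out = layer(x_q, x_kv)          # residual output, same shape as x_q
    print(out.shape)                # expected: torch.Size([2, 32, 256])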