Source code for towhee.models.perceiver.create_self_attention

# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz


import os

try:
    # pylint: disable=unused-import
    import fairscale
except ImportError:
    # fairscale is an optional dependency; install it on demand if missing.
    os.system("pip install fairscale")
from fairscale.nn import checkpoint_wrapper

from towhee.models.perceiver.sequential import Sequential
from towhee.models.perceiver.residual import Residual
from towhee.models.perceiver.self_attention import SelfAttention
from towhee.models.perceiver.mlp import mlp


def create_self_attention(num_channels: int, num_heads: int, dropout: float, activation_checkpoint: bool = False):
    """
    Self-attention block for Perceiver https://arxiv.org/pdf/2103.03206.pdf.

    Args:
        num_channels (`int`):
            Number of q channels.
        num_heads (`int`):
            Number of parallel attention heads.
        dropout (`float`):
            Dropout probability.
        activation_checkpoint (`bool`):
            Use activation checkpointing.

    Returns (`nn.Module`):
        Configured self-attention layer.
    """
    layer = Sequential(
        Residual(SelfAttention(num_channels, num_heads, dropout), dropout),
        Residual(mlp(num_channels), dropout)
    )
    return layer if not activation_checkpoint else checkpoint_wrapper(layer)
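

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original
    # module): assumes torch is available and that the returned block accepts
    # tensors shaped (batch, num_tokens, num_channels), following the
    # Perceiver convention.
    import torch

    # Build a self-attention block over 64-channel tokens with 4 heads.
    block = create_self_attention(num_channels=64, num_heads=4, dropout=0.1)

    x = torch.randn(2, 128, 64)
    out = block(x)
    print(out.shape)  # expected: torch.Size([2, 128, 64])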