# original code from https://github.com/krasserm/perceiver-io
# modified by Zilliz
from torch import nn
[docs]class Residual(nn.Module):
"""
Residual module for Perceiver https://arxiv.org/pdf/2103.03206.pdf.
Args:
module (`nn.Module`):
nn.Module.
dropout (`nn.Module`):
Dropout probability.
"""
[docs] def __init__(self, module: nn.Module, dropout: float):
super().__init__()
self.module = module
self.dropout = nn.Dropout(p=dropout)
self.dropout_p = dropout
[docs] def forward(self, *args, **kwargs):
x = self.module(*args, **kwargs)
return self.dropout(x) + args[0]