# This PyTorch implementation is adapted from: https://github.com/lucidrains/CoCa-pytorch
#
# All modifications are made by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange
from towhee.models.layers.activations.swiglu import SwiGLU


class LayerNorm(nn.Module):
    """LayerNorm with a learnable scale (gamma) and a non-trainable bias (beta)."""
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class Residual(nn.Module):
    """Residual wrapper: returns fn(x, ...) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE): returns per-position rotation angles."""
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
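

# Usage note: in the upstream CoCa-pytorch implementation these rotary helpers
# are applied to the query and key tensors of the causal self-attention layers
# before the attention dot product, roughly as follows (illustrative names):
#   pos = rotary_emb(seq_len, device=q.device)   # rotary_emb = RotaryEmbedding(dim_head)
#   q, k = apply_rotary_pos_emb(pos, q), apply_rotary_pos_emb(pos, k)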


class CrossAttention(nn.Module):
    """
    Cross attention using multi-query attention with one-headed key / values,
    as in PaLM, with an optional parallel feedforward branch.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        if context_dim is None:
            context_dim = dim

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # optional parallel feedforward branch
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Einstein notation:
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)

        # scale
        q = q * self.scale

        # get key / values
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge and combine heads
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if self.ff is not None:
            # pylint: disable=E1102
            out = out + self.ff(x)

        return out
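

# Minimal usage sketch (illustrative only; the tensor shapes and hyperparameters
# below are assumptions chosen to show the expected input/output layout, not
# values from the CoCa paper or the upstream repo).
if __name__ == "__main__":
    cross_attn = CrossAttention(
        dim=512,            # query / output dimension
        context_dim=1024,   # context (e.g. image token) dimension
        dim_head=64,
        heads=8,
        parallel_ff=True,
        norm_context=True,
    )
    x = torch.randn(2, 16, 512)         # batch of 16 query tokens
    context = torch.randn(2, 49, 1024)  # batch of 49 context tokens
    out = cross_attn(x, context)
    print(out.shape)  # torch.Size([2, 16, 512])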