# Source code for towhee.models.layers.cross_attention

# Pytorch implementation is adapted from: https://github.com/lucidrains/CoCa-pytorch
#
# All modifications are made by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import einsum, nn
import torch.nn.functional as F

from einops import rearrange
from towhee.models.layers.activations.swiglu import SwiGLU

class LayerNorm(nn.Module):
    """
    Layer normalization with a learnable scale and a fixed (non-learnable) shift.

    The scale ``gamma`` is an ``nn.Parameter`` initialised to ones. The shift
    ``beta`` is registered as a buffer of zeros, so it is part of the module's
    state but is not a trainable parameter.

    Args:
        dim (int): size of the last (normalized) dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        initial_scale = torch.ones(dim)
        initial_shift = torch.zeros(dim)
        self.gamma = nn.Parameter(initial_scale)
        # Registered as a buffer (not a Parameter) so the shift stays fixed at zero.
        self.register_buffer("beta", initial_shift)

    def forward(self, x):
        """
        Normalize ``x`` over its last dimension, then apply ``gamma`` and ``beta``.

        Args:
            x (torch.Tensor): input whose last dimension has size ``dim``.

        Returns:
            torch.Tensor: tensor with the same shape as ``x``.
        """
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)


class Residual(nn.Module):
    """
    Wrap a callable ``fn`` and add its input back to its output.

    Args:
        fn (callable): module or function applied to ``x``. Any extra positional
            or keyword arguments given to ``forward`` are passed through to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """
        Return ``self.fn(x, *args, **kwargs) + x``.

        Args:
            x (torch.Tensor): input to ``fn`` and to the skip connection.

        Returns:
            torch.Tensor: the output of ``fn`` plus ``x``.
        """
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class RotaryEmbedding(nn.Module):
    """
    Rotary position embedding angle table.

    Precomputes inverse frequencies ``1 / 10000 ** (k / dim)`` for
    ``k = 0, 2, 4, ... < dim`` and stores them in the ``inv_freq`` buffer.

    Args:
        dim (int): feature dimension the frequencies are computed for.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        """
        Build the angle table for positions ``0 .. max_seq_len - 1``.

        Args:
            max_seq_len (int): number of positions.
            device: device on which the position index tensor is created.

        Returns:
            torch.Tensor: tensor of shape ``(max_seq_len, 2 * len(inv_freq))``.
            Entry ``[p, k]`` equals ``p * inv_freq[k % len(inv_freq)]``, i.e. the
            per-position angles concatenated with themselves along the last dim.
        """
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # Outer product: one angle per (position, frequency) pair.
        angles = positions.unsqueeze(-1) * self.inv_freq
        return torch.cat([angles, angles], dim=-1)


def rotate_half(x):
    """
    Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    The last dimension is split into two equal contiguous halves, ``x1`` (first)
    and ``x2`` (second). The result is ``-x2`` concatenated with ``x1``.

    Args:
        x (torch.Tensor): input whose last dimension is even (``rearrange`` with
            ``j=2`` requires an even split).

    Returns:
        torch.Tensor: tensor with the same shape as ``x``.
    """
    halves = rearrange(x, "... (j d) -> ... j d", j=2)
    first, second = halves.unbind(dim=-2)
    return torch.cat((second.neg(), first), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """
    Rotate ``t`` by the angles in ``pos``.

    Computes ``t * cos(pos) + rotate_half(t) * sin(pos)``.

    Args:
        pos (torch.Tensor): rotation angles; must broadcast against ``t``.
        t (torch.Tensor): tensor to rotate; its last dimension must be even
            (required by ``rotate_half``).

    Returns:
        torch.Tensor: the rotated tensor (broadcast shape of ``pos`` and ``t``).
    """
    cos_part = t * pos.cos()
    sin_part = rotate_half(t) * pos.sin()
    return cos_part + sin_part


class CrossAttention(nn.Module):
    """
    Cross attention in the multi-query style used by PaLM, with an optional parallel feedforward.

    Queries come from ``x`` and are split into ``heads`` heads of size
    ``dim_head``. Keys and values come from ``context`` as a single head of size
    ``dim_head``, shared by every query head.

    Args:
        dim (int): feature size of ``x`` and of the output.
        context_dim (int, optional): feature size of ``context``; ``None`` means ``dim``.
        dim_head (int): size of each attention head. Default 64.
        heads (int): number of query heads. Default 8.
        parallel_ff (bool): if True, add a feedforward branch (Linear -> SwiGLU ->
            Linear) computed from the normalized ``x`` to the attention output.
        ff_mult (int): the feedforward hidden width is ``ff_mult * dim``. Default 4.
        norm_context (bool): if True, apply ``LayerNorm`` to ``context``;
            otherwise ``context`` is passed through unchanged. Default False.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = dim if context_dim is None else context_dim

        # Submodules are created in a fixed order so that seeded initialization
        # and state_dict key order stay stable.
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # One key head and one value head (dim_head each), shared by all query heads.
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        if parallel_ff:
            ff_inner_dim = ff_mult * dim
            # The first projection is twice as wide as the input of the last
            # Linear (ff_inner_dim), so SwiGLU is expected to halve the width.
            self.ff = nn.Sequential(
                nn.Linear(dim, ff_inner_dim * 2, bias=False),
                SwiGLU(),
                nn.Linear(ff_inner_dim, dim, bias=False),
            )
        else:
            self.ff = None

    def forward(self, x, context):
        """
        Attend from ``x`` to ``context``.

        Einstein notation used below:
            b - batch
            h - heads
            n, i, j - sequence length (base sequence length, source, target)
            d - feature dimension

        Args:
            x (torch.Tensor): queries of shape ``(b, n, dim)``; the ``rearrange``
                pattern ``"b n (h d)"`` requires this rank.
            context (torch.Tensor): keys/values source of shape
                ``(b, j, context_dim)``.

        Returns:
            torch.Tensor: tensor of shape ``(b, n, dim)``.
        """
        # Pre-layernorm on both inputs.
        x = self.norm(x)
        context = self.context_norm(context)

        # Multi-head queries, pre-scaled by dim_head ** -0.5.
        q = rearrange(self.to_q(x), "b n (h d) -> b h n d", h=self.heads) * self.scale

        # A single key/value head, broadcast across all query heads by the einsums.
        k, v = self.to_kv(context).chunk(2, dim=-1)

        sim = einsum("b h i d, b j d -> b h i j", q, k)
        # Subtracting the row max does not change the softmax result.
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = self.to_out(rearrange(out, "b h n d -> b n (h d)"))

        # Parallel feedforward branch, computed from the same normalized x.
        if self.ff is not None:
            out = out + self.ff(x)  # pylint: disable=E1102
        return out