# PyTorch implementation adapted from: https://github.com/lucidrains/CoCa-pytorch
#
# All modifications are made by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat
from towhee.models.layers.cross_attention import LayerNorm, Residual, RotaryEmbedding, apply_rotary_pos_emb, \
CrossAttention
from towhee.models.layers.activations.swiglu import SwiGLU


class CoCa(nn.Module):
"""
CoCa model.
Args:
dim (`int`):
model dimension
img_encoder (`nn.Module`):
vision transformer - image encoder, returning image embeddings as (batch, seq, dim).
image_dim (`int`):
image embedding dimension, if not the same as model dimensions.
num_tokens (`int`):
number of text tokens.
unimodal_depth (`int`):
depth of the unimodal transformer.
multimodal_depth (`int`):
depth of the multimodal transformer.
dim_head (`int`):
dimension per attention head.
heads (`int`):
number of attention heads.
caption_loss_weight (`float`):
weight on the autoregressive caption loss.
contrastive_loss_weight (`float`):
weight on the contrastive loss between image and text CLS embeddings.
"""

    def __init__(
self,
*,
dim,
num_tokens,
unimodal_depth,
multimodal_depth,
image_dim=None,
num_img_queries=256,
dim_head=64,
heads=8,
ff_mult=4,
img_encoder=None,
caption_loss_weight=1.,
contrastive_loss_weight=1.,
pad_id=0
):
super().__init__()
self.dim = dim
self.pad_id = pad_id
self.caption_loss_weight = caption_loss_weight
self.contrastive_loss_weight = contrastive_loss_weight
# token embeddings
self.token_emb = nn.Embedding(num_tokens, dim)
self.text_cls_token = nn.Parameter(torch.randn(dim))
# image encoder
self.img_encoder = img_encoder
# attention pooling for image tokens
self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim))
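        # one extra query (+1) is reserved for the image CLS embedding used in the contrastive loss;
        # the remaining queries become the image tokens consumed by the multimodal cross-attention layers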
self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads,
norm_context=True)
self.img_attn_pool_norm = LayerNorm(dim)
self.text_cls_norm = LayerNorm(dim)
# contrastive learning temperature
self.temperature = nn.Parameter(torch.Tensor([1.]))
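        # learnable temperature, stored in log space and exponentiated in forward before scaling the similarity logits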
# unimodal layers
self.unimodal_layers = nn.ModuleList([])
# pylint: disable=W0612
for ind in range(unimodal_depth):
self.unimodal_layers.append(
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
)
# multimodal layers
self.multimodal_layers = nn.ModuleList([])
for ind in range(multimodal_depth):
self.multimodal_layers.append(nn.ModuleList([
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult))
]))
# to logits
self.to_logits = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, num_tokens, bias=False)
)
        # the output projection to logits is weight-tied to the token embedding (uncommon, but works)
self.to_logits[-1].weight = self.token_emb.weight
nn.init.normal_(self.token_emb.weight, std=0.02)

    def embed_text(self, text):
        batch, seq = text.shape
text_tokens = self.token_emb(text)
# append text cls tokens
text_cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=batch)
text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)
# create specific mask for text cls token at the end
# to prevent it from attending to padding
cls_mask = rearrange(text != self.pad_id, "b j -> b 1 j")
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
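        # attn_mask: (batch, seq + 1, seq + 1); rows for ordinary tokens are all True (causality is handled
        # inside the attention block), while the final row lets the cls token attend only to non-padding positions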
# go through unimodal layers
for attn_ff in self.unimodal_layers:
text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)
# get text cls token
text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
text_embeds = self.text_cls_norm(text_cls_tokens)
return text_embeds, text_tokens

    def embed_image(self, images=None, image_tokens=None):
# encode images into embeddings
# with the img_encoder passed in at init
# it can also accept precomputed image tokens
assert not ((images is not None) and (image_tokens is not None))
if images is not None:
assert self.img_encoder is not None, "img_encoder must be passed in for automatic image encoding"
image_tokens = self.img_encoder(images)
# attention pool image tokens
img_queries = repeat(self.img_queries, "n d -> b n d", b=image_tokens.shape[0])
img_queries = self.img_attn_pool(img_queries, image_tokens)
img_queries = self.img_attn_pool_norm(img_queries)
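        # first pooled query is the image CLS embedding (for the contrastive loss);
        # the remaining queries are the image tokens fed to the multimodal cross-attention layers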
return img_queries[:, 0], img_queries[:, 1:]

    def forward(
self,
text,
images=None,
image_tokens=None,
labels=None,
return_loss=False,
return_embeddings=False
):
batch, device = text.shape[0], text.device
if return_loss and (labels is None):
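            # teacher forcing: inputs are tokens[:-1], autoregressive targets are tokens[1:]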
text, labels = text[:, :-1], text[:, 1:]
text_embeds, text_tokens = self.embed_text(text)
image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens)
        # return the text and image embeddings if requested
if return_embeddings:
return text_embeds, image_embeds
# go through multimodal layers
for attn_ff, cross_attn in self.multimodal_layers:
text_tokens = attn_ff(text_tokens)
text_tokens = cross_attn(text_tokens, image_tokens)
logits = self.to_logits(text_tokens)
if not return_loss:
return logits
# shorthand
ce = F.cross_entropy
# calculate caption loss (cross entropy loss)
logits = rearrange(logits, "b n c -> b c n")
caption_loss = ce(logits, labels, ignore_index=self.pad_id)
caption_loss = caption_loss * self.caption_loss_weight
# calculate contrastive loss
sim = einsum("i d, j d -> i j", text_embeds, image_embeds)
sim = sim * self.temperature.exp()
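        # matched text/image pairs lie on the diagonal of sim; symmetric cross entropy over rows and columns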
contrastive_labels = torch.arange(batch, device=device)
contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5
contrastive_loss = contrastive_loss * self.contrastive_loss_weight
return caption_loss + contrastive_loss
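

# A minimal smoke-test sketch (not part of the original module). It assumes the ParallelTransformerBlock
# used by the unimodal/multimodal layers is defined elsewhere in this module, and feeds precomputed
# image tokens instead of attaching an image encoder.
if __name__ == "__main__":
    model = CoCa(
        dim=512,
        num_tokens=10000,
        unimodal_depth=2,
        multimodal_depth=2,
        image_dim=512,
        num_img_queries=64,
        dim_head=64,
        heads=8,
    )
    text = torch.randint(0, 10000, (2, 32))   # (batch, seq) token ids
    image_tokens = torch.randn(2, 49, 512)    # (batch, num_patches, image_dim) precomputed image features
    loss = model(text, image_tokens=image_tokens, return_loss=True)
    loss.backward()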