Source code for towhee.models.violet.violet

# original code from https://github.com/SwinTransformer/Video-Swin-Transformer
# modified by Zilliz.
import torch
import os
try:
    import torchvision
except ModuleNotFoundError:
    os.system("pip install torchvision")
    import torchvision
try:
    import transformers
except ModuleNotFoundError:
    os.system("pip install transformers")
    import transformers
from towhee.models.video_swin_transformer.video_swin_transformer import VideoSwinTransformer
# from video_swin import SwinTransformer3D


class EncImg(torch.nn.Module):
    """
    EncImg module: video encoder built on Video Swin Transformer.
    """
    def __init__(self):
        super().__init__()
        self.swin = VideoSwinTransformer(patch_size=(2, 4, 4), patch_norm=True, depths=[2, 2, 18, 2],
                                         window_size=(8, 7, 7), num_heads=[3, 6, 12, 24], mlp_ratio=4.,
                                         drop_path_rate=0.2, stride=(1, 4, 4))
        # self.swin = SwinTransformer3D()
        # self.swin.load_state_dict(torch.load("./_snapshot/ckpt_video-swin.pt", map_location="cpu"))
        self.emb_cls = torch.nn.Parameter(0.02 * torch.randn(1, 1, 1, 768))
        self.emb_pos = torch.nn.Parameter(0.02 * torch.randn(1, 1, 1 + 14 ** 2, 768))
        self.emb_len = torch.nn.Parameter(0.02 * torch.randn(1, 6, 1, 768))
        self.norm = torch.nn.LayerNorm(768)

    def forward(self, img):
        bb, tt, _, hh, ww = img.shape
        h, w = hh // 32, ww // 32
        # ImageNet mean/std normalization
        img = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
        f_img = self.swin(img.transpose(1, 2)).transpose(1, 2)
        f_img = f_img.permute(0, 1, 3, 4, 2).view([bb, tt, h * w, 768])
        # prepend a per-frame [CLS] token, then add spatial position and temporal embeddings
        f_img = torch.cat([self.emb_cls.expand([bb, tt, -1, -1]), f_img], dim=2)
        f_img += self.emb_pos.expand([bb, tt, -1, -1])[:, :, :1 + h * w, :] \
            + self.emb_len.expand([bb, -1, 1 + h * w, -1])[:, :tt, :, :]
        # flatten the per-frame tokens into one sequence of length tt * (1 + h * w)
        f_img = self.norm(f_img).view([bb, tt * (1 + h * w), -1])

        # attention mask: every visual token is valid
        m_img = torch.ones(1 + h * w).long().unsqueeze(0).unsqueeze(0)
        m_img = m_img.expand([bb, tt, -1]).contiguous().view([bb, tt * (1 + h * w)])
        return f_img, m_img
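
# Illustrative usage sketch added for clarity (not part of the original source). The clip
# size (batch of 2, 4 frames of 224 x 224 RGB) is an assumption; with h = w = 224 // 32 = 7
# each frame contributes 1 + 49 tokens, so the flattened sequence has 4 * 50 = 200 tokens.
def _example_enc_img():
    enc_img = EncImg()
    video = torch.rand(2, 4, 3, 224, 224)  # (batch, frames, channels, height, width)
    f_img, m_img = enc_img(video)
    # expected under the assumptions above: f_img -> (2, 200, 768), m_img -> (2, 200)
    return f_img, m_img
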

class EncTxt(torch.nn.Module):
    """
    EncTxt module: text encoder that reuses the embedding layer of a pretrained BERT.
    """
    def __init__(self):
        super().__init__()
        bert = transformers.BertModel.from_pretrained("bert-base-uncased")
        self.emb_txt = bert.embeddings

    def forward(self, txt):
        f_txt = self.emb_txt(txt)
        return f_txt
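
# Illustrative usage sketch added for clarity (not part of the original source). The choice
# of BertTokenizerFast and the max_length of 16 are assumptions; EncTxt only needs token ids
# produced by a tokenizer matching "bert-base-uncased".
def _example_enc_txt():
    tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
    batch = tokenizer(["a person riding a bike"], padding="max_length", max_length=16,
                      return_tensors="pt")
    enc_txt = EncTxt()
    f_txt = enc_txt(batch["input_ids"])  # token embeddings of shape (1, 16, 768)
    return f_txt, batch["attention_mask"]
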

class VioletBase(torch.nn.Module):
    """
    VioletBase module: fuses EncImg and EncTxt features with a BERT cross-modal encoder.
    """
    def __init__(self):
        super().__init__()
        self.enc_img, self.enc_txt = EncImg(), EncTxt()
        bert = transformers.BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.mask_ext, self.trsfr = bert.get_extended_attention_mask, bert.bert.encoder

    def go_feat(self, img, txt, mask):
        feat_img, mask_img = self.enc_img(img)
        feat_txt, mask_txt = self.enc_txt(txt), mask
        return feat_img, mask_img, feat_txt, mask_txt

    def go_cross(self, feat_img, mask_img, feat_txt, mask_txt):
        # concatenate visual and textual tokens and fuse them with the BERT encoder
        feat, mask = torch.cat([feat_img, feat_txt], dim=1), torch.cat([mask_img, mask_txt], dim=1)
        mask = self.mask_ext(mask, mask.shape, mask.device)
        out = self.trsfr(feat, mask, output_attentions=True)
        return out["last_hidden_state"], out["attentions"]

    def load_ckpt(self, ckpt):
        if ckpt == "":
            print("===== Init VIOLET =====")
            return
        # copy only the parameters whose names and shapes match the current model
        ckpt_new, ckpt_old = torch.load(ckpt, map_location="cpu"), self.state_dict()
        key_old = set(ckpt_old.keys())
        for k in ckpt_new:
            if k in ckpt_old and ckpt_new[k].shape == ckpt_old[k].shape:
                ckpt_old[k] = ckpt_new[k]
                key_old.remove(k)
        self.load_state_dict(ckpt_old)
        print("===== Not Load:", key_old, "=====")
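
# Illustrative end-to-end sketch added for clarity (not part of the original source). The
# video/text shapes, the tokenizer choice, and the empty checkpoint path are assumptions.
def _example_violet_base():
    model = VioletBase()
    model.load_ckpt("")  # "" skips checkpoint loading and keeps the initialized weights
    video = torch.rand(1, 4, 3, 224, 224)
    tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
    batch = tokenizer(["a person riding a bike"], padding="max_length", max_length=16,
                      return_tensors="pt")
    feat_img, mask_img, feat_txt, mask_txt = model.go_feat(video, batch["input_ids"],
                                                           batch["attention_mask"])
    hidden, attentions = model.go_cross(feat_img, mask_img, feat_txt, mask_txt)
    return hidden, attentions  # hidden covers the concatenated visual + text tokens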