# Source code for towhee.models.drl.drl

# original code from https://github.com/foolwood/DRL/blob/main/tvr/models/modeling.py
# modified by Zilliz

from collections import OrderedDict
from types import SimpleNamespace
import torch
import logging
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from towhee.models import clip

from towhee.models.drl.until_module import convert_weights
from towhee.models.drl.module_cross import CrossModel, Transformer as TransformerClip
from towhee.models.drl.until_module import LayerNorm, AllGather, AllGather2, CrossEn

# Module-level logger.
logger = logging.getLogger(__name__)

# Short call aliases for the `.apply` entry points of AllGather / AllGather2
# (presumably torch.autograd.Function subclasses -- see until_module).
allgather = AllGather.apply
allgather2 = AllGather2.apply


class DRL(nn.Module):
    """
    This is a PyTorch implementation of the paper Disentangled Representation Learning for Text-Video Retrieval.

    Args:
        base_encoder (`str`):
            CLIP encoder backbone. default: `clip_vit_b32`
        agg_module (`str`):
            Feature aggregation module for video. default: `seqTransf`, choices=[`ndone`, `seqLSTM`, `seqTransf`].
            Any value other than `seqLSTM` / `seqTransf` leaves the frame features unchanged.
        interaction (`str`):
            Interaction type for retrieval. default: `wti`. `get_similarity_logits` supports
            `dp`, `xti`, `ti` and `wti`.
        wti_arch (`int`):
            Select an architecture for the weight branch (1, 2 or 3). default: 2.
        cdcr (`int`):
            Channel decorrelation regularization variant (0 disables it). default: 3.
        cdcr_alpha1 (`float`):
            Coefficient 1 (on-diagonal term) for channel decorrelation regularization. default: 1.0.
        cdcr_alpha2 (`float`):
            Coefficient 2 (off-diagonal term) for channel decorrelation regularization. default: 0.06.
        cdcr_lambda (`float`):
            Weight of the channel decorrelation loss in the training loss. default: 0.001.
        cross_num_hidden_layers (`int`):
            Number of hidden layers for cross transformer interaction (only used when `interaction="xti"`).
        backbone_pretrained (`bool`):
            Forwarded to `clip.create_model` as `pretrained`. default: False.
    """

    def __init__(self,
                 base_encoder="clip_vit_b32",
                 agg_module="seqTransf",
                 interaction="wti",
                 wti_arch=2,
                 cdcr=3,
                 cdcr_alpha1=1.0,
                 cdcr_alpha2=0.06,
                 cdcr_lambda=0.001,
                 cross_num_hidden_layers=None,
                 backbone_pretrained=False
                 ):
        super().__init__()
        self.base_encoder = base_encoder
        self.agg_module = agg_module
        self.interaction = interaction
        self.wti_arch = wti_arch
        self.cdcr = cdcr
        self.cdcr_alpha1 = cdcr_alpha1
        self.cdcr_alpha2 = cdcr_alpha2
        self.cdcr_lambda = cdcr_lambda

        self.clip = clip.create_model(model_name=base_encoder, pretrained=backbone_pretrained, jit=False,
                                      clip4clip=True)
        # Cloned snapshot of the CLIP weights. `Module.state_dict()` returns tensors that share storage with
        # the live parameters, and `self.apply(self.init_weights)` below re-initialises nn.Linear /
        # nn.Embedding weights in place (CLIP's included). Without the clone, the snapshot would be
        # overwritten as well.
        state_dict = {key: val.clone() for key, val in self.clip.state_dict().items()}

        context_length = state_dict["positional_embedding"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64

        if torch.cuda.is_available():
            convert_weights(self.clip)  # fp16

        cross_config = SimpleNamespace(**{
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "max_position_embeddings": 128,
            "num_attention_heads": 8,
            "num_hidden_layers": 4,
            "vocab_size": 512,
        })
        cross_config.max_position_embeddings = context_length
        cross_config.hidden_size = transformer_width
        self.cross_config = cross_config

        if self.interaction == "xti":
            if cross_num_hidden_layers is not None:
                # NOTE: also changes the depth of `transformerClip` built below for seqTransf
                cross_config.num_hidden_layers = cross_num_hidden_layers
            self.cross = CrossModel(cross_config)
            self.similarity_dense = nn.Linear(cross_config.hidden_size, 1)
        elif self.interaction == "mlp":
            self.similarity_dense = nn.Sequential(nn.Linear(transformer_width * 2, transformer_width),
                                                  nn.ReLU(inplace=True), nn.Linear(transformer_width, 1))
        elif self.interaction == "wti":
            if self.wti_arch == 1:
                self.text_weight_fc = nn.Linear(transformer_width, 1)
                self.video_weight_fc = nn.Linear(transformer_width, 1)
            elif self.wti_arch == 2:
                self.text_weight_fc = nn.Sequential(
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, 1))
                self.video_weight_fc = nn.Sequential(
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, 1))
            elif self.wti_arch == 3:
                self.text_weight_fc = nn.Sequential(
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, 1))
                self.video_weight_fc = nn.Sequential(
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, transformer_width), nn.ReLU(inplace=True),
                    nn.Linear(transformer_width, 1))

        if self.agg_module in ["seqLSTM", "seqTransf"]:
            self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings,
                                                          cross_config.hidden_size)
        if self.agg_module == "seqTransf":
            self.transformerClip = TransformerClip(width=transformer_width,  # pylint: disable=invalid-name
                                                   layers=cross_config.num_hidden_layers,
                                                   heads=transformer_heads)
        if self.agg_module == "seqLSTM":
            self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size,
                                       batch_first=True, bidirectional=False, num_layers=1)

        self.loss_fct = CrossEn()

        self.apply(self.init_weights)  # random init must before loading pretrain
        # Put the CLIP backbone weights back: the line above re-initialised CLIP's Linear / Embedding
        # layers too, which would silently discard `backbone_pretrained` weights.
        self.clip.load_state_dict(state_dict, strict=False)

        # ===> Initialization trick [HARD CODE]
        new_state_dict = OrderedDict()
        if self.interaction == "xti":
            contain_cross = any("cross.transformer" in key for key in state_dict)
            if not contain_cross:
                for key, val in state_dict.items():
                    if key == "positional_embedding":
                        new_state_dict["cross.embeddings.position_embeddings.weight"] = val.clone()
                    elif key.startswith("transformer.resblocks"):
                        num_layer = int(key.split(".")[2])
                        # cut from beginning
                        if num_layer < cross_config.num_hidden_layers:
                            new_state_dict["cross." + key] = val.clone()

        if self.agg_module in ["seqLSTM", "seqTransf"]:
            contain_frame_position = any("frame_position_embeddings" in key for key in state_dict)
            if not contain_frame_position:
                for key, val in state_dict.items():
                    if key == "positional_embedding":
                        new_state_dict["frame_position_embeddings.weight"] = val.clone()
                    elif self.agg_module == "seqTransf" and key.startswith("transformer.resblocks"):
                        num_layer = int(key.split(".")[2])
                        # cut from beginning
                        if num_layer < cross_config.num_hidden_layers:
                            new_state_dict[key.replace("transformer.", "transformerClip.")] = val.clone()

        self.load_state_dict(new_state_dict, strict=False)  # only update new state (seqTransf/seqLSTM/tightTransf)
        # <=== End of initialization trick

    @staticmethod
    def _is_distributed():
        """Return True when a default torch.distributed process group is available and initialised."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def _gather_args(self):
        """
        Build the argument object passed to `allgather` / `allgather2`.

        A `config` attribute assigned on the model takes precedence; the original implementation read
        `self.config`, which this class never defines itself. Otherwise `world_size` and `local_rank`
        (the global rank) are read from the default process group.
        NOTE(review): assumes AllGather / AllGather2 only read `world_size` and `local_rank` from this
        object, as in the upstream DRL code -- confirm against `until_module`.
        """
        config = getattr(self, "config", None)
        if config is not None:
            return config
        if self._is_distributed():
            return SimpleNamespace(world_size=torch.distributed.get_world_size(),
                                   local_rank=torch.distributed.get_rank())
        return SimpleNamespace(world_size=1, local_rank=0)

    def forward(self, text_ids, text_mask, video, video_mask=None):
        """
        Compute the training loss for a batch of text-video pairs.

        Args:
            text_ids: token ids, flattened to (-1, text_len).
            text_mask: token mask, flattened to (-1, text_len).
            video: frames of shape (b, n_v, 3, H, W); converted with `torch.as_tensor(...).float()`.
            video_mask: frame mask, flattened to (-1, n_v). When None, every frame is treated as valid.

        Returns:
            `sim_loss + cdcr_loss * cdcr_lambda` in training mode, otherwise None.
        """
        text_ids = text_ids.view(-1, text_ids.shape[-1])
        text_mask = text_mask.view(-1, text_mask.shape[-1])
        # b x n_v x 3 x H x W -> (b x n_v) x 3 x H x W
        video = torch.as_tensor(video).float()
        b, n_v, d, h, w = video.shape
        if video_mask is None:
            video_mask = torch.ones(b, n_v, dtype=torch.long, device=video.device)
        video_mask = video_mask.view(-1, video_mask.shape[-1])
        video = video.view(b * n_v, d, h, w)

        text_feat, video_feat = self.get_text_video_feat(text_ids, video, video_mask, shaped=True)

        if not self.training:
            return None
        sim_matrix1, sim_matrix2, cdcr_loss = self.get_similarity_logits(text_feat, video_feat, text_mask,
                                                                         video_mask, shaped=True)
        sim_loss = (self.loss_fct(sim_matrix1) + self.loss_fct(sim_matrix2)) / 2.0
        return sim_loss + cdcr_loss * self.cdcr_lambda

    def get_text_feat(self, text_ids, shaped=False):
        """Encode token ids with CLIP and return per-token hidden states of shape (b, n_t, d)."""
        if shaped is False:
            text_ids = text_ids.view(-1, text_ids.shape[-1])
        bs_pair = text_ids.size(0)
        text_feat = self.clip.encode_text(text_ids, clip4clip=True, return_hidden=True)[1].float()
        return text_feat.view(bs_pair, -1, text_feat.size(-1))

    def get_video_feat(self, video, video_mask, shaped=False):
        """Encode frames with CLIP, regroup them per video and apply the configured aggregation module."""
        if shaped is False:
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, n_v, d, h, w = video.shape
            video = video.view(b * n_v, d, h, w)
        bs_pair = video_mask.size(0)
        video_feat = self.clip.encode_image(video).float()
        video_feat = video_feat.view(bs_pair, -1, video_feat.size(-1))
        return self.aggvideo_feat(video_feat, video_mask, self.agg_module)

    def get_text_video_feat(self, text_ids, video, video_mask, shaped=False):
        """Return `(text_feat, video_feat)` from `get_text_feat` and `get_video_feat`."""
        if shaped is False:
            text_ids = text_ids.view(-1, text_ids.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, n_v, d, h, w = video.shape
            video = video.view(b * n_v, d, h, w)
        text_feat = self.get_text_feat(text_ids, shaped=True)
        video_feat = self.get_video_feat(video, video_mask, shaped=True)
        return text_feat, video_feat

    def get_video_avg_feat(self, video_feat, video_mask):
        """Masked mean over frames: (b, n_v, d) -> (b, d); fully-masked videos divide by 1 instead of 0."""
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        video_feat = video_feat * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        return torch.sum(video_feat, dim=1) / video_mask_un_sum

    def get_text_sep_feat(self, text_feat, text_mask):
        """Pick the feature at the last valid token (index `sum(mask) - 1`) of each text: -> (b, 1, d)."""
        text_feat = text_feat.contiguous()
        text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :]
        return text_feat.unsqueeze(1).contiguous()

    def aggvideo_feat(self, video_feat, video_mask, agg_module):
        """
        Aggregate frame features temporally with a residual connection.

        `seqLSTM` runs an LSTM over the valid frames; `seqTransf` adds frame position embeddings and runs
        `transformerClip` with masked attention. Any other value (e.g. "ndone") returns the features unchanged.
        """
        video_feat = video_feat.contiguous()
        if agg_module == "seqLSTM":
            # Sequential type: LSTM
            video_feat_original = video_feat
            video_feat = pack_padded_sequence(video_feat, torch.sum(video_mask, dim=-1).cpu(),
                                              batch_first=True, enforce_sorted=False)
            video_feat, _ = self.lstm_visual(video_feat)
            if self.training:
                self.lstm_visual.flatten_parameters()
            video_feat, _ = pad_packed_sequence(video_feat, batch_first=True)
            # pad back to the original number of frames with the untouched trailing frames
            video_feat = torch.cat(
                (video_feat, video_feat_original[:, video_feat.size(1):, ...].contiguous()), dim=1)
            video_feat = video_feat + video_feat_original
        elif agg_module == "seqTransf":
            # Sequential type: Transformer Encoder
            video_feat_original = video_feat
            seq_length = video_feat.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=video_feat.device)
            position_ids = position_ids.unsqueeze(0).expand(video_feat.size(0), -1)
            frame_position_embeddings = self.frame_position_embeddings(position_ids)
            video_feat = video_feat + frame_position_embeddings
            # additive attention mask: large negative value on padded frames
            extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0
            extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1)
            video_feat = video_feat.permute(1, 0, 2)  # NLD -> LND
            video_feat = self.transformerClip(video_feat, extended_video_mask)
            video_feat = video_feat.permute(1, 0, 2)  # LND -> NLD
            video_feat = video_feat + video_feat_original
        return video_feat

    @staticmethod
    def _standardize(feat):
        """Zero-mean / unit-std normalise `feat` along dim 0 (i.e. per channel)."""
        return (feat - feat.mean(0)) / feat.std(0)

    def _decorrelation_loss(self, c):
        """
        Channel decorrelation loss on a square cross-correlation matrix `c`.

        `cdcr_alpha1 * sum((diag(c) - 1)^2) + cdcr_alpha2 * sum(offdiag(c)^2)`. Note: modifies `c` in place.
        """
        dd = c.shape[0]
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        # drop the diagonal: the flattened matrix minus its first element reshapes to (dd-1, dd+1) with
        # the diagonal entries in the last column
        off_diag = c.flatten()[1:].view(dd - 1, dd + 1)[:, :-1].pow_(2).sum()
        return on_diag * self.cdcr_alpha1 + off_diag * self.cdcr_alpha2

    def dp_interaction(self, text_feat, video_feat, text_mask, video_mask):
        """
        Dot-product interaction between the last-token text feature and the masked-mean video feature.

        Returns:
            (t2v_logits, v2t_logits, cdcr_loss); logits are scaled by `exp(clip.logit_scale)` in training.
        """
        text_feat = self.get_text_sep_feat(text_feat, text_mask)  # b x 1 x d

        if self.training and self._is_distributed():  # batch merge here
            gather_args = self._gather_args()
            text_feat = allgather(text_feat, gather_args)
            video_feat = allgather(video_feat, gather_args)
            video_mask = allgather(video_mask, gather_args)
            torch.distributed.barrier()  # force sync

        text_feat = text_feat.squeeze(1)  # b x 1 x d -> b x d
        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)  # b x d
        video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)
        video_feat = self.get_video_avg_feat(video_feat, video_mask)  # b x n_v x d -> b x d
        video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)

        retrieve_logits = torch.matmul(text_feat, video_feat.t())
        if not self.training:
            return retrieve_logits, retrieve_logits.T, 0.0

        logit_scale = self.clip.logit_scale.exp()
        retrieve_logits = logit_scale * retrieve_logits
        if self.cdcr == 0:
            return retrieve_logits, retrieve_logits.T, 0.0

        z_a_norm = self._standardize(text_feat)  # b x d
        z_b_norm = self._standardize(video_feat)  # b x d
        bd = z_a_norm.shape[0]
        c = torch.einsum("bm,bn->mn", z_a_norm, z_b_norm) / bd  # d x d cross-correlation matrix
        return retrieve_logits, retrieve_logits.T, self._decorrelation_loss(c)

    def _get_cross_feat(self, text_feat, video_feat, text_mask, video_mask):
        """Run the cross transformer over concatenated text tokens (type 0) and frames (type 1)."""
        concat_feats = torch.cat((text_feat, video_feat), dim=1)  # concatenate tokens and frames
        concat_mask = torch.cat((text_mask, video_mask), dim=1)
        text_type_ = torch.zeros_like(text_mask)
        video_type_ = torch.ones_like(video_mask)
        concat_type = torch.cat((text_type_, video_type_), dim=1)

        cross_layers, pooled_feat = self.cross(concat_feats, concat_type, concat_mask,
                                               output_all_encoded_layers=True)
        cross_feat = cross_layers[-1]
        return cross_feat, pooled_feat, concat_mask

    def xti_interaction(self, text_feat, video_feat, text_mask, video_mask):
        """
        Cross-transformer interaction: score every (text, video) pair with `similarity_dense`.

        Returns:
            (t2v_logits, v2t_logits, 0.0); in training the logits are rolled by this rank's offset and
            scaled by `exp(clip.logit_scale)`.
        """
        text_feat = self.get_text_sep_feat(text_feat, text_mask)  # b x 1 x d
        b_text, s_text, d_text = text_feat.size()
        b_video, s_video, d_video = video_feat.size()

        rank = 0
        if self.training and self._is_distributed():  # batch merge here
            gather_args = self._gather_args()
            rank = gather_args.local_rank
            text_feat_full = allgather2(text_feat, gather_args)
            video_feat_full = allgather2(video_feat, gather_args)
            video_mask_full = allgather2(video_mask, gather_args)
            text_feat = text_feat_full[b_text * rank: b_text * (1 + rank)]
            video_feat = video_feat_full[b_video * rank: b_video * (1 + rank)]
            torch.distributed.barrier()  # force sync
        else:
            text_feat_full = text_feat
            video_feat_full = video_feat
            video_mask_full = video_mask

        b_text_full = text_feat_full.shape[0]
        b_video_full = video_feat_full.shape[0]
        text_mask = torch.ones(text_feat.size(0), 1).to(device=text_mask.device, dtype=text_mask.dtype)
        text_mask_full = torch.ones(text_feat_full.size(0), 1).to(device=text_mask.device, dtype=text_mask.dtype)

        # tV: local texts against all videos
        text_feat_1 = text_feat.unsqueeze(1).repeat(1, b_video_full, 1, 1)  # b_t x B_v x n_t x d_t
        text_feat_1 = text_feat_1.view(-1, s_text, d_text)  # (b_t x B_v) x n_t x d_t
        text_mask_1 = text_mask.unsqueeze(1).repeat(1, b_video_full, 1)  # b_t x B_v x 1
        text_mask_1 = text_mask_1.view(-1, s_text)  # (b_t x B_v) x 1
        video_feat_1 = video_feat_full.unsqueeze(0).repeat(b_text, 1, 1, 1)  # b_t x B_v x n_v x d_v
        video_feat_1 = video_feat_1.view(-1, s_video, d_video)  # (b_t x B_v) x n_v x d_v
        video_mask_1 = video_mask_full.unsqueeze(0).repeat(b_text, 1, 1)  # b_t x B_v x n_v
        video_mask_1 = video_mask_1.view(-1, s_video)  # (b_t x B_v) x n_v

        # vT: all texts against local videos
        text_feat_2 = text_feat_full.unsqueeze(1).repeat(1, b_video, 1, 1)  # B_t x b_v x n_t x d_t
        text_feat_2 = text_feat_2.view(-1, s_text, d_text)  # (B_t x b_v) x n_t x d_t
        text_mask_2 = text_mask_full.unsqueeze(1).repeat(1, b_video, 1)  # B_t x b_v x 1
        text_mask_2 = text_mask_2.view(-1, s_text)  # (B_t x b_v) x 1
        video_feat_2 = video_feat.unsqueeze(0).repeat(b_text_full, 1, 1, 1)  # B_t x b_v x n_v x d_v
        video_feat_2 = video_feat_2.view(-1, s_video, d_video)  # (B_t x b_v) x n_v x d_v
        video_mask_2 = video_mask.unsqueeze(0).repeat(b_text_full, 1, 1)  # B_t x b_v x n_v
        video_mask_2 = video_mask_2.view(-1, s_video)  # (B_t x b_v) x n_v

        _, pooled_feat, _ = self._get_cross_feat(text_feat_1, video_feat_1, text_mask_1, video_mask_1)
        retrieve_logits_tv = self.similarity_dense(pooled_feat).squeeze(-1).view(b_text, b_video_full)
        _, pooled_feat, _ = self._get_cross_feat(text_feat_2, video_feat_2, text_mask_2, video_mask_2)
        retrieve_logits_vt = self.similarity_dense(pooled_feat).squeeze(-1).view(b_text_full, b_video).T

        if self.training:
            logit_scale = self.clip.logit_scale.exp()
            retrieve_logits_tv = torch.roll(retrieve_logits_tv, -b_text * rank, -1)
            retrieve_logits_vt = torch.roll(retrieve_logits_vt, -b_video * rank, -1)
            retrieve_logits_tv = logit_scale * retrieve_logits_tv
            retrieve_logits_vt = logit_scale * retrieve_logits_vt
        return retrieve_logits_tv, retrieve_logits_vt, 0.0

    @staticmethod
    def _matched_token_feats(text_feat, video_feat, text_mask, video_mask, max_idx1, max_idx2):
        """
        For the matched pairs (text i, video i), collect the valid tokens and their best-matching counterparts.

        Returns:
            (t_feat, max_v_feat, v_feat, max_t_feat, t_mask, v_mask): valid text tokens and, row-aligned,
            their best-matching frames; valid frames and their best-matching text tokens; and the flattened
            boolean masks used for the selection.
        """
        # keep only the diagonal (i, i) pairs of the argmax tables: a x b x n -> a x n
        max_idx1 = max_idx1[torch.arange(max_idx1.shape[0]), torch.arange(max_idx1.shape[1])]
        max_idx2 = max_idx2[torch.arange(max_idx2.shape[0]), torch.arange(max_idx2.shape[1])]
        max_t_feat = text_feat[torch.arange(max_idx2.shape[0]).repeat_interleave(max_idx2.shape[1]),
                               max_idx2.flatten()]
        max_v_feat = video_feat[torch.arange(max_idx1.shape[0]).repeat_interleave(max_idx1.shape[1]),
                                max_idx1.flatten()]
        t_mask = text_mask.flatten().type(torch.bool)
        v_mask = video_mask.flatten().type(torch.bool)
        t_feat = text_feat.reshape(-1, text_feat.shape[-1])[t_mask]
        v_feat = video_feat.reshape(-1, video_feat.shape[-1])[v_mask]
        return t_feat, max_v_feat[t_mask], v_feat, max_t_feat[v_mask], t_mask, v_mask

    def wti_interaction(self, text_feat, video_feat, text_mask, video_mask):
        """
        Token-wise (`ti`) or weighted token-wise (`wti`) interaction.

        Returns:
            (t2v_logits, v2t_logits, cdcr_loss); in training the logits are scaled by
            `exp(clip.logit_scale)` and `cdcr_loss` follows the `cdcr` variant (1, 2 or 3; otherwise 0.0).
        """
        if self.training and self._is_distributed():  # batch merge here
            gather_args = self._gather_args()
            text_feat = allgather(text_feat, gather_args)
            video_feat = allgather(video_feat, gather_args)
            text_mask = allgather(text_mask, gather_args)
            video_mask = allgather(video_mask, gather_args)
            torch.distributed.barrier()  # force sync

        text_weight, video_weight = None, None
        if self.interaction == "wti":
            # softmax token weights; padded positions get -inf so they receive zero weight
            text_weight = self.text_weight_fc(text_feat).squeeze(2)  # b x n_t x d -> b x n_t
            text_weight.masked_fill_((1 - text_mask).bool(), float("-inf"))
            text_weight = torch.softmax(text_weight, dim=-1)  # b x n_t

            video_weight = self.video_weight_fc(video_feat).squeeze(2)  # b x n_v x d -> b x n_v
            video_weight.masked_fill_((1 - video_mask).bool(), float("-inf"))
            video_weight = torch.softmax(video_weight, dim=-1)  # b x n_v

        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
        video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)

        retrieve_logits = torch.einsum("atd,bvd->abtv", [text_feat, video_feat])
        retrieve_logits = torch.einsum("abtv,at->abtv", [retrieve_logits, text_mask.to(retrieve_logits.dtype)])
        retrieve_logits = torch.einsum("abtv,bv->abtv", [retrieve_logits, video_mask.to(retrieve_logits.dtype)])
        text_sum = text_mask.sum(-1)
        video_sum = video_mask.sum(-1)

        # best-matching frame for every word, and best-matching word for every frame
        t2v_logits, max_idx1 = retrieve_logits.max(dim=-1)  # abtv -> abt
        v2t_logits, max_idx2 = retrieve_logits.max(dim=-2)  # abtv -> abv
        if self.interaction == "ti":  # token-wise interaction
            t2v_logits = torch.sum(t2v_logits, dim=2) / (text_sum.unsqueeze(1))
            v2t_logits = torch.sum(v2t_logits, dim=2) / (video_sum.unsqueeze(0))
            retrieve_logits = (t2v_logits + v2t_logits) / 2.0
        elif self.interaction == "wti":  # weighted token-wise interaction
            t2v_logits = torch.einsum("abt,at->ab", [t2v_logits, text_weight])
            v2t_logits = torch.einsum("abv,bv->ab", [v2t_logits, video_weight])
            retrieve_logits = (t2v_logits + v2t_logits) / 2.0

        if not self.training:
            return retrieve_logits, retrieve_logits.T, 0.0

        logit_scale = self.clip.logit_scale.exp()
        retrieve_logits = logit_scale * retrieve_logits

        if self.cdcr == 1:
            # simple random: one random token among the first `sum(mask)` positions of each sample
            text_feat = text_feat[torch.arange(text_feat.shape[0]),
                                  torch.randint_like(text_sum, 0, 10000) % text_sum, :]
            video_feat = video_feat[torch.arange(video_feat.shape[0]),
                                    torch.randint_like(video_sum, 0, 10000) % video_sum, :]
            z_a_norm = self._standardize(text_feat)  # b x d
            z_b_norm = self._standardize(video_feat)  # b x d
            bd = z_a_norm.shape[0]
            c = torch.einsum("ac,ad->cd", z_a_norm, z_b_norm) / bd  # d x d cross-correlation matrix
            return retrieve_logits, retrieve_logits.T, self._decorrelation_loss(c)

        if self.cdcr == 2:
            # select max
            t_feat, max_v_feat, v_feat, max_t_feat, _, _ = self._matched_token_feats(
                text_feat, video_feat, text_mask, video_mask, max_idx1, max_idx2)
            z_a_norm = self._standardize(t_feat)  # (b x n_t) x d
            z_b_norm = self._standardize(max_v_feat)  # (b x n_t) x d
            x_a_norm = self._standardize(v_feat)  # (b x n_v) x d
            x_b_norm = self._standardize(max_t_feat)  # (b x n_v) x d
            c1 = torch.einsum("ac,ad->cd", z_a_norm, z_b_norm) / z_a_norm.shape[0]  # d x d
            c2 = torch.einsum("ac,ad->cd", x_a_norm, x_b_norm) / x_a_norm.shape[0]  # d x d
            c = (c1 + c2) / 2.0
            return retrieve_logits, retrieve_logits.T, self._decorrelation_loss(c)

        if self.cdcr == 3:
            # select max, weighted by the token weights
            if text_weight is None:
                # "ti" has no learned weights: use the uniform weights over valid tokens that the
                # "ti" score above implicitly uses (sum / count)
                text_weight = text_mask / text_sum.unsqueeze(-1)
                video_weight = video_mask / video_sum.unsqueeze(-1)
            t_feat, max_v_feat, v_feat, max_t_feat, t_mask, v_mask = self._matched_token_feats(
                text_feat, video_feat, text_mask, video_mask, max_idx1, max_idx2)
            text_weight = text_weight.flatten()[t_mask]
            video_weight = video_weight.flatten()[v_mask]
            z_a_norm = self._standardize(t_feat)  # (b x n_t) x d
            z_b_norm = self._standardize(max_v_feat)  # (b x n_t) x d
            x_a_norm = self._standardize(v_feat)  # (b x n_v) x d
            x_b_norm = self._standardize(max_t_feat)  # (b x n_v) x d
            bd = text_feat.shape[0]
            # weighted cross-correlation sum_a w[a] * z_a[a, c] * z_b[a, d]; scaling the rows first avoids
            # materialising an (n x d x d) intermediate
            c1 = torch.einsum("ac,ad->cd", z_a_norm * text_weight.unsqueeze(-1), z_b_norm) / bd  # d x d
            c2 = torch.einsum("ac,ad->cd", x_a_norm * video_weight.unsqueeze(-1), x_b_norm) / bd  # d x d
            c = (c1 + c2) / 2.0
            return retrieve_logits, retrieve_logits.T, self._decorrelation_loss(c)

        return retrieve_logits, retrieve_logits.T, 0.0

    def get_similarity_logits(self, text_feat, video_feat, text_mask, video_mask, shaped=False):
        """
        Dispatch to the interaction selected by `self.interaction`.

        Raises:
            NotImplementedError: for interactions other than `dp`, `xti`, `ti` and `wti`.
        """
        if shaped is False:
            text_mask = text_mask.view(-1, text_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

        if self.interaction == "dp":
            t2v_logits, v2t_logits, cdcr_loss = self.dp_interaction(text_feat, video_feat, text_mask, video_mask)
        elif self.interaction == "xti":
            t2v_logits, v2t_logits, cdcr_loss = self.xti_interaction(text_feat, video_feat, text_mask, video_mask)
        elif self.interaction in ["ti", "wti"]:
            t2v_logits, v2t_logits, cdcr_loss = self.wti_interaction(text_feat, video_feat, text_mask, video_mask)
        else:
            raise NotImplementedError
        return t2v_logits, v2t_logits, cdcr_loss

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    def init_weights(self, module):
        """
        Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if "beta" in dir(module) and "gamma" in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
def create_model(
        base_encoder: str = "clip_vit_b32",
        agg_module: str = "seqTransf",
        interaction: str = "wti",
        wti_arch: int = 2,
        cdcr: int = 3,
        cdcr_alpha1: float = 1.0,
        cdcr_alpha2: float = 0.06,
        cdcr_lambda: float = 0.001,
        cross_num_hidden_layers: int = None,
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None
) -> "DRL":
    """
    Build a DRL model.

    Args:
        base_encoder (`str`):
            Base_encoder in DRL model, `clip_vit_b32` or `clip_vit_b16`.
        agg_module (`str`):
            Feature aggregation module for video. default: `seqTransf`, choices=[`ndone`, `seqLSTM`, `seqTransf`]
        interaction (`str`):
            Interaction type for retrieval. default: `wti`.
        wti_arch (`int`):
            Select an architecture for weight branch. default: 2.
        cdcr (`int`):
            Channel decorrelation regularization. default: 3.
        cdcr_alpha1 (`float`):
            Coefficient 1 for channel decorrelation regularization. default: 1.0.
        cdcr_alpha2 (`float`):
            Coefficient 2 for channel decorrelation regularization. default: 0.06.
        cdcr_lambda (`float`):
            Coefficient for channel decorrelation regularization. default: 0.001.
        cross_num_hidden_layers (`int`):
            Number of hidden layers for cross transformer interaction.
        pretrained (`bool`):
            Whether to load pretrained weights from `weights_path`. default is False.
        weights_path (`str`):
            Pretrained model local path. default is None; required when `pretrained` is True.
        device (`str`):
            Model device, `cpu` or `cuda`. default: `cuda` when available, otherwise `cpu`.

    Returns:
        DRL: the model moved to `device` (and, when `pretrained`, loaded from `weights_path`).

    Raises:
        ValueError: if `pretrained` is True but `weights_path` is None.

    >>> from towhee.models import drl
    >>> model = drl.create_model("clip_vit_b32")
    >>> model.__class__.__name__
    'DRL'
    """
    # Fail fast, before the (expensive) backbone is built.
    if pretrained and weights_path is None:
        raise ValueError("weights_path is None")

    model = DRL(base_encoder=base_encoder, agg_module=agg_module, interaction=interaction, wti_arch=wti_arch,
                cdcr=cdcr, cdcr_alpha1=cdcr_alpha1, cdcr_alpha2=cdcr_alpha2, cdcr_lambda=cdcr_lambda,
                cross_num_hidden_layers=cross_num_hidden_layers)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    if torch.device(device).type == "cpu":
        # DRL converts its CLIP backbone to fp16 whenever CUDA is available, even if the model is then
        # placed on CPU; half precision may be unsupported or slow on CPU, so keep a CPU model in fp32.
        model.float()

    if pretrained:
        # NOTE(review): torch.load unpickles the checkpoint -- only load files from trusted sources.
        state_dict = torch.load(weights_path, map_location=device)
        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata  # pylint: disable=protected-access

        def load(module, prefix=""):
            # Recursive non-strict load that collects problems instead of raising (mirrors load_state_dict).
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(  # pylint: disable=protected-access
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():  # pylint: disable=protected-access
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix="")

        if missing_keys:
            logger.info("Weights of %s not initialized from pretrained model: %s",
                        model.__class__.__name__, "\n   " + "\n   ".join(missing_keys))
        if unexpected_keys:
            logger.info("Weights from pretrained model not used in %s: %s",
                        model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys))
        if error_msgs:
            logger.error("Weights from pretrained model cause errors in %s: %s",
                         model.__class__.__name__, "\n   " + "\n   ".join(error_msgs))

    return model