Source code for towhee.models.action_clip.visual_prompt

# Code for "ActionCLIP: ActionCLIP: A New Paradigm for Action Recognition"
# Mengmeng Wang, Jiazheng Xing, Yong Liu
#
# Built on top of official implementation at https://github.com/sallymmx/ActionCLIP
#
# Modifications by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import torch
from torch import nn

from towhee.models.utils.weight_init import trunc_normal_


class LayerNorm(nn.Module):
    """
    Construct a layernorm module in the TF style (epsilon inside the square root).
    """
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

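# Example (illustrative sketch, not part of the original module): normalization is
# over the last dimension, so hidden_size must match the feature dimension.
#
#     ln = LayerNorm(hidden_size=512)
#     y = ln(torch.randn(4, 8, 512))   # output shape: (4, 8, 512)
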
class QuickGELU(nn.Module):
    """
    QuickGELU
    """
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

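# Example (illustrative, not part of the original module): QuickGELU is the
# sigmoid-based GELU approximation used in CLIP; it is elementwise and
# shape-preserving.
#
#     act = QuickGELU()
#     y = act(torch.randn(2, 512))   # output shape: (2, 512)
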
class ResidualAttentionBlock(nn.Module):
    """
    ResidualAttentionBlock
    """
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ('c_fc', nn.Linear(d_model, d_model * 4)),
            ('gelu', QuickGELU()),
            ('c_proj', nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

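# Example (illustrative, not part of the original module): nn.MultiheadAttention is
# used with its default batch_first=False, so the block expects
# (seq_len, batch, d_model) inputs.
#
#     block = ResidualAttentionBlock(d_model=512, n_head=8)
#     y = block(torch.randn(8, 2, 512))   # output shape: (8, 2, 512)
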
class TAggregate(nn.Module):
    """
    TAggregate
    """
    def __init__(self, clip_length=None, embed_dim=2048, n_layers=6):
        super().__init__()
        self.clip_length = clip_length
        drop_rate = 0.
        enc_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8)
        self.transformer_enc = nn.TransformerEncoder(enc_layer, num_layers=n_layers, norm=nn.LayerNorm(embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, clip_length + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        with torch.no_grad():
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            with torch.no_grad():
                trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        nvids = x.shape[0]
        cls_tokens = self.cls_token.expand(nvids, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x.transpose_(1, 0)
        o = self.transformer_enc(x)
        return o[0]

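# Example (illustrative, not part of the original module): TAggregate prepends a
# learnable cls token to the per-frame features and returns its encoding, so the
# input must be (batch, clip_length, embed_dim) and the output is (batch, embed_dim).
#
#     agg = TAggregate(clip_length=8, embed_dim=512, n_layers=6)
#     video_feat = agg(torch.randn(2, 8, 512))   # output shape: (2, 512)
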
class TemporalTransformer(nn.Module):
    """
    TemporalTransformer
    """
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

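# Example (illustrative, not part of the original module): a stack of
# ResidualAttentionBlocks, again operating on (seq_len, batch, width) inputs.
#
#     temporal = TemporalTransformer(width=512, layers=6, heads=8)
#     y = temporal(torch.randn(8, 2, 512))   # output shape: (8, 2, 512)
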
class VisualPrompt(nn.Module):
    """
    VisualPrompt
    """
    def __init__(self, sim_head, clip_state_dict, num_frames):
        super().__init__()
        self.sim_header = sim_head
        self.num_frames = num_frames
        assert sim_head in ['meanP', 'LSTM', 'Transf', 'Conv_1D', 'Transf_cls']

        if self.sim_header == 'LSTM' \
                or self.sim_header == 'Transf' \
                or self.sim_header == 'Transf_cls' \
                or self.sim_header == 'Conv_1D':
            embed_dim = clip_state_dict['text_projection'].shape[1]
            context_length = clip_state_dict['positional_embedding'].shape[0]
            # vocab_size = clip_state_dict['token_embedding.weight'].shape[0]
            transformer_width = clip_state_dict['ln_final.weight'].shape[0]
            transformer_heads = transformer_width // 64
            # transformer_layers = len(
            #     set(k.split('.')[2] for k in clip_state_dict if k.startswith('transformer.resblocks')))
            self.frame_position_embeddings = nn.Embedding(context_length, embed_dim)
        if self.sim_header == 'Transf':
            self.transformer = TemporalTransformer(width=embed_dim, layers=6, heads=transformer_heads)
        if self.sim_header == 'LSTM':
            self.lstm_visual = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
                                       batch_first=True, bidirectional=False, num_layers=1)

        self.apply(self.init_weights)

        if self.sim_header == 'Transf_cls':
            self.transformer = TAggregate(clip_length=self.num_frames, embed_dim=embed_dim, n_layers=6)

        if self.sim_header == 'Conv_1D':
            self.shift = nn.Conv1d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False)
            weight = torch.zeros(embed_dim, 1, 3)
            weight[:embed_dim // 4, 0, 0] = 1.0
            weight[embed_dim // 4:embed_dim // 4 + embed_dim // 2, 0, 1] = 1.0
            weight[-embed_dim // 4:, 0, 2] = 1.0
            self.shift.weight = nn.Parameter(weight)

    def init_weights(self, module):
        """
        Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    def forward(self, x):
        _, t, c = x.size()  # (batch, frames, feature dim)
        x = x.contiguous()
        if self.sim_header == 'meanP':
            pass
        elif self.sim_header == 'Conv_1D':
            x_original = x
            x = x.view(-1, c, t)
            x = self.shift(x.float())
            x = x.permute(0, 2, 1)
            x = x.type(x_original.dtype) + x_original
        elif self.sim_header == 'Transf':
            x_original = x
            seq_length = t
            position_ids = torch.arange(seq_length, dtype=torch.long, device=x.device)
            position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1)
            frame_position_embeddings = self.frame_position_embeddings(position_ids)
            x = x + frame_position_embeddings
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = x.type(x_original.dtype) + x_original
        elif self.sim_header == 'LSTM':
            x_original = x
            x, _ = self.lstm_visual(x.float())
            self.lstm_visual.flatten_parameters()
            x = torch.cat((x, x_original[:, x.size(1):, ...].contiguous()), dim=1)
            x = x.type(x_original.dtype) + x_original
        elif self.sim_header == 'Transf_cls':
            x_original = x
            return self.transformer(x).type(x_original.dtype)
        else:
            raise ValueError(f'Unknown sim header: {self.sim_header}')
        return x.mean(dim=1, keepdim=False)
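# Usage sketch (illustrative, not part of the original module): VisualPrompt only
# reads shapes from the CLIP state dict, so a minimal dict with 'text_projection',
# 'positional_embedding' and 'ln_final.weight' entries is enough for a dry run;
# the shapes below are assumptions chosen for illustration.
#
#     clip_sd = {
#         'text_projection': torch.empty(512, 512),
#         'positional_embedding': torch.empty(77, 512),
#         'ln_final.weight': torch.empty(512),
#     }
#     head = VisualPrompt(sim_head='Transf', clip_state_dict=clip_sd, num_frames=8)
#     frame_feats = torch.randn(2, 8, 512)    # (batch, num_frames, embed_dim)
#     video_feat = head(frame_feats)          # output shape: (2, 512)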