# Source code for towhee.models.timesformer.timesformer_block

# Reference:
#   [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095)
#
# Built on top of codes from / Copyright (c) Facebook, Inc. and its affiliates.
# Modifications & additions by / Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from torch import nn

try:
    from einops import rearrange
except ModuleNotFoundError:
    os.system('pip install einops')
    from einops import rearrange

from towhee.models.layers.attention import MultiHeadAttention
from towhee.models.layers.droppath import DropPath
from towhee.models.layers.mlp import Mlp


class Block(nn.Module):
    """
    TimeSformer block.

    Args:
        dim (`int`):
            feature dimension
        num_heads (`int`):
            number of attention heads
        mlp_ratio (`float=4.`):
            ratio of the MLP hidden dimension to `dim`
        qkv_bias (`bool=False`):
            if use qkv_bias
        qk_scale (`float=None`):
            number to scale qk
        drop (`float=0.`):
            drop rate of projection in attention & MLP layers
        attn_drop (`float=0.`):
            attention drop rate
        drop_path (`float=0.1`):
            drop probability in DropPath (values <= 0 disable it)
        act_layer (`nn.module=nn.GELU`):
            module used in activation layer
        norm_layer (`nn.module=nn.LayerNorm`):
            module used in normalization layer
        attention_type (`str='divided_space_time'`):
            type of TimeSformer attention from
            ['divided_space_time', 'space_only', 'joint_space_time', 'frozen_in_time']

    Raises:
        AssertionError: if `attention_type` is not one of the supported values.

    Example:
        >>> import torch
        >>> from towhee.models.timesformer.timesformer_block import Block
        >>>
        >>> test_shape = (1, 196*8+1, 768)
        >>> fake_x = torch.rand(test_shape)
        >>> model = Block(dim=768, num_heads=12)
        >>> out = model(fake_x, b=1, t=8, w=int(224/16))
        >>> print(out.shape)
        torch.Size([1, 1569, 768])
    """

    # Attention types that run one attention pass over the whole token sequence.
    _JOINT_TYPES = ('space_only', 'joint_space_time')
    # Attention types that run temporal attention followed by spatial attention.
    _DIVIDED_TYPES = ('divided_space_time', 'frozen_in_time')

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type='divided_space_time'):
        super().__init__()
        if attention_type not in self._JOINT_TYPES + self._DIVIDED_TYPES:
            # Raised explicitly so the check survives `python -O`; the exception type is
            # kept as AssertionError to stay compatible with the `assert` it replaces.
            raise AssertionError(
                f'Unsupported attention_type {attention_type!r}; '
                f'expected one of {list(self._JOINT_TYPES + self._DIVIDED_TYPES)}'
            )
        self.attention_type = attention_type

        # The spatial and temporal attention layers share the same configuration.
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop,
            proj_drop_ratio=drop,
        )

        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(dim, **attn_kwargs)

        # Temporal Attention Parameters
        if self.attention_type in self._DIVIDED_TYPES:
            self.temporal_norm1 = norm_layer(dim)
            self.temporal_attn = MultiHeadAttention(dim, **attn_kwargs)
            self.temporal_fc = nn.Linear(dim, dim)

        # Drop path
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x: torch.Tensor, b: int, t: int, w: int) -> torch.Tensor:
        """
        Apply the block to a token sequence.

        Args:
            x (`torch.Tensor`):
                tokens laid out as (b, 1 + h*w*t, m): one CLS token followed by the patch tokens
            b (`int`):
                batch size
            t (`int`):
                number of frames
            w (`int`):
                number of patches along the width; `h` is derived from the token count

        Returns:
            (`torch.Tensor`):
                tensor with the same shape as `x`

        Note:
            `b`, `t` and `w` are only used by the divided attention types.
        """
        if self.attention_type in self._JOINT_TYPES:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            return x + self.drop_path(self.mlp(self.norm2(x)))
        if self.attention_type in self._DIVIDED_TYPES:
            return self._forward_divided(x, b, t, w)
        # Only reachable if `attention_type` was changed after construction; fail loudly
        # instead of implicitly returning None.
        raise AssertionError(f'Unsupported attention_type {self.attention_type!r}')

    def _forward_divided(self, x: torch.Tensor, b: int, t: int, w: int) -> torch.Tensor:
        """Temporal attention, then per-frame spatial attention, then the MLP."""
        num_spatial_tokens = (x.size(1) - 1) // t
        h = num_spatial_tokens // w

        init_cls_token = x[:, 0, :].unsqueeze(1)
        patch_tokens = x[:, 1:, :]

        # Temporal: every spatial location attends across its t frames.
        xt = rearrange(patch_tokens, 'b (h w t) m -> (b h w) t m', b=b, h=h, w=w, t=t)
        res_temporal = self.drop_path(self.temporal_attn(self.temporal_norm1(xt)))
        res_temporal = rearrange(res_temporal, '(b h w) t m -> b (h w t) m', b=b, h=h, w=w, t=t)
        res_temporal = self.temporal_fc(res_temporal)
        xt = patch_tokens + res_temporal

        # Spatial: every frame attends over its h*w patches plus its own copy of the CLS token.
        cls_token = init_cls_token.repeat(1, t, 1)
        cls_token = rearrange(cls_token, 'b t m -> (b t) m', b=b, t=t).unsqueeze(1)
        xs = rearrange(xt, 'b (h w t) m -> (b t) (h w) m', b=b, h=h, w=w, t=t)
        xs = torch.cat((cls_token, xs), 1)
        res_spatial = self.drop_path(self.attn(self.norm1(xs)))

        # CLS token
        cls_token = res_spatial[:, 0, :]
        cls_token = rearrange(cls_token, '(b t) m -> b t m', b=b, t=t)
        cls_token = torch.mean(cls_token, 1, True)  # averaging for every frame
        res_spatial = rearrange(res_spatial[:, 1:, :], '(b t) (h w) m -> b (h w t) m', b=b, h=h, w=w, t=t)

        # 'frozen_in_time' adds the spatial residual to the block input, whereas
        # 'divided_space_time' adds it to the temporally attended tokens.
        residual_base = patch_tokens if self.attention_type == 'frozen_in_time' else xt

        # Mlp
        x = torch.cat((init_cls_token, residual_base), 1) + torch.cat((cls_token, res_spatial), 1)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
# if __name__ == '__main__':
#     import torch
#
#     # batch=1, num_frames=8, img_size=224*224, patch_size=16*16
#     test_shape = (1, 196*8+1, 768)
#     fake_x = torch.rand(test_shape)
#     model = Block(dim=768, num_heads=12)
#     out = model(fake_x, b=1, t=8, w=int(224/16))
#     assert(out.shape == test_shape)