# Source code for towhee.models.video_swin_transformer.video_swin_transformer
# original code from https://github.com/SwinTransformer/Video-Swin-Transformer
# modified by Zilliz.
import torch
from torch import nn
from torch.utils import model_zoo
from einops import rearrange
from towhee.models.layers.patch_embed3d import PatchEmbed3D
from towhee.models.layers.patch_merging3d import PatchMerging3D
from towhee.models.video_swin_transformer.video_swin_transformer_block import VideoSwinTransformerBlock
from towhee.models.video_swin_transformer import get_configs
from towhee.models.utils.init_vit_weights import init_vit_weights
from collections import OrderedDict
import logging
class VideoSwinTransformer(nn.Module):
    """
    Video Swin Transformer.
    Ze Liu, Jia Ning, Yue Cao, Yixuan Wei, Zheng Zhang, Stephen Lin, Han Hu
    https://arxiv.org/pdf/2106.13230.pdf

    Args:
        pretrained (`str`):
            Location of the pretrained weights. It is passed to
            `model_zoo.load_url` for 3D weights, or to `torch.load` when
            `pretrained2d` is True. `None` or `""` skips loading. Default: None.
        pretrained2d (`bool`):
            Load image (2D) pretrained weights and inflate them. Default: False.
        patch_size (`tuple[int]`):
            Patch size. Default: (4, 4, 4).
        in_chans (`int`):
            Number of input image channels. Default: 3.
        embed_dim (`int`):
            Number of linear projection output channels. Default: 96.
        depths (`tuple[int]`):
            Depths of each Swin Transformer stage. Default: (2, 2, 6, 2).
        num_heads (`tuple[int]`):
            Number of attention heads of each stage. Default: (3, 6, 12, 24).
        window_size (`tuple[int]`):
            Window size, indexed as (0, 1, 2); entries 1 and 2 size the 2D
            relative position table when inflating. Default: (2, 7, 7).
        mlp_ratio (`float`):
            Ratio of mlp hidden dim to embedding dim. Default: 4.
        num_classes (`int`):
            Number of output classes of `fc_cls`. Default: 1000.
        qkv_bias (`bool`):
            If True, add a learnable bias to query, key, value. Default: True.
        qk_scale (`float`):
            Override default qk scale of head_dim ** -0.5 if set. Default: None.
        drop_rate (`float`):
            Dropout rate. Default: 0.
        attn_drop_rate (`float`):
            Attention dropout rate. Default: 0.
        drop_path_rate (`float`):
            Stochastic depth rate. Default: 0.2.
        cls_dropout_ratio (`float`):
            Dropout probability applied to the pooled features in
            `forward_features`; 0 disables that dropout. Default: 0.4.
        norm_layer (`nn.Module`):
            Normalization layer. Default: nn.LayerNorm.
        patch_norm (`bool`):
            If True, add normalization after patch embedding. Default: False.
        frozen_stages (`int`):
            Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        use_checkpoint (`bool`):
            Forwarded to every stage (presumably enables gradient
            checkpointing there -- confirm in VideoSwinTransformerBlock).
            Default: False.
        depth_mode (`str`):
            How a 4-channel (RGB + depth) input is embedded. One of `None`,
            "summed_rgb_d_tokens", "rgbd" or "separate_d_tokens" (the last one
            raises NotImplementedError). Default: None.
        depth_patch_embed_separate_params (`bool`):
            If True, build a dedicated `depth_patch_embed` layer; otherwise
            `patch_embed` is rebuilt with an additional variable channel
            (only valid for "rgbd"). Only used when `depth_mode` is set.
            Default: True.
        stride (`tuple[int]`):
            Stride size for patch embed3d. Default: None.
        device (`str`):
            Device the pretrained 3D checkpoint is mapped to while loading.
            Default: "cpu".
    """

    def __init__(self,
                 pretrained=None,
                 pretrained2d=False,
                 patch_size=(4, 4, 4),
                 in_chans=3,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=(2, 7, 7),
                 mlp_ratio=4.,
                 num_classes=1000,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 cls_dropout_ratio=0.4,
                 norm_layer=nn.LayerNorm,
                 patch_norm=False,
                 frozen_stages=-1,
                 use_checkpoint=False,
                 depth_mode=None,
                 depth_patch_embed_separate_params=True,
                 stride=None,
                 device="cpu"
                 ):
        super().__init__()
        self.pretrained = pretrained
        self.pretrained2d = pretrained2d
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.frozen_stages = frozen_stages
        self.window_size = window_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        # get_patch_embedding() dispatches on these two attributes, so they
        # must exist on every instance, including RGB-only models.
        self.depth_mode = depth_mode
        self.depth_patch_embed_separate_params = depth_patch_embed_separate_params

        embed_norm = norm_layer if self.patch_norm else None

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed3D(
            patch_size=patch_size, c=in_chans, embed_dim=embed_dim,
            norm_layer=embed_norm, stride=stride)

        if depth_mode is not None:
            assert depth_mode in ["separate_d_tokens", "summed_rgb_d_tokens", "rgbd"]
            if depth_mode in ["separate_d_tokens", "summed_rgb_d_tokens"]:
                depth_chans = 1
                assert (
                    depth_patch_embed_separate_params
                ), "separate tokenization needs separate parameters"
                if depth_mode == "separate_d_tokens":
                    raise NotImplementedError()
            else:
                assert depth_mode == "rgbd"
                depth_chans = 4

            if depth_patch_embed_separate_params:
                self.depth_patch_embed = PatchEmbed3D(
                    patch_size=patch_size,
                    c=depth_chans,
                    embed_dim=embed_dim,
                    norm_layer=embed_norm,
                    stride=stride
                )
            else:
                # share parameters with patch_embed
                # delete the layer we built above
                del self.patch_embed
                assert depth_chans == 4
                self.patch_embed = PatchEmbed3D(
                    patch_size=patch_size,
                    c=3,
                    embed_dim=embed_dim,
                    additional_variable_channels=[1],
                    norm_layer=embed_norm,
                    stride=stride
                )

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, rising linearly
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            first_block = sum(depths[:i_layer])
            layer = VideoSwinTransformerBlock(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[first_block:first_block + depths[i_layer]],
                norm_layer=norm_layer,
                # every stage but the last one is followed by a patch merging
                downsample=PatchMerging3D if i_layer < self.num_layers - 1 else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.num_features = int(embed_dim * 2**(self.num_layers - 1))

        # add a norm layer for each output
        self.norm = norm_layer(self.num_features)

        # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels.
        self.avg_pool3d = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.cls_dropout_ratio = cls_dropout_ratio
        if self.cls_dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.cls_dropout_ratio)
        else:
            self.dropout = None
        self.fc_cls = nn.Linear(self.num_features, self.num_classes)

        self.apply(init_vit_weights)

        # if load pretrained weights
        if self.pretrained not in ["", None]:
            self.load_pretrained_weights(self.pretrained, device=device)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the first `frozen_stages` stages in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """
        Set the training mode while keeping the frozen stages in eval mode.

        `nn.Module.train` switches every submodule, which would silently
        un-freeze the stages selected by `frozen_stages`; they are frozen
        again right after.

        Args:
            mode (`bool`):
                True for training mode, False for evaluation mode.

        Returns:
            The module itself.
        """
        super().train(mode)
        self._freeze_stages()
        return self

    def inflate_weights(self):
        """
        Inflate the swin2d parameters to swin3d.
        The differences between swin3d and swin2d mainly lie in an extra
        axis. To utilize the pretrained parameters in 2d model,
        the weight of swin2d models should be inflated to fit in the shapes of
        the 3d counterpart.

        The checkpoint is read from `self.pretrained` with `torch.load` and is
        expected to hold the weights under the "model" key.
        """
        checkpoint = torch.load(self.pretrained, map_location="cpu")
        state_dict = checkpoint["model"]

        # delete relative_position_index and attn_mask since we always re-init them
        stale_keys = [k for k in state_dict
                      if "relative_position_index" in k or "attn_mask" in k]
        for k in stale_keys:
            del state_dict[k]

        # repeat the 2D kernel along the new temporal axis and rescale it by
        # the number of repetitions
        temporal_patch = self.patch_size[0]
        state_dict["patch_embed.proj.weight"] = \
            state_dict["patch_embed.proj.weight"].unsqueeze(2).repeat(1, 1, temporal_patch, 1, 1) \
            / temporal_patch

        # bicubic interpolate relative_position_bias_table if not match
        wd = self.window_size[0]
        target_hw = (2 * self.window_size[1] - 1, 2 * self.window_size[2] - 1)
        l2 = target_hw[0] * target_hw[1]
        # build the model state dict once instead of once per table
        current_state = self.state_dict()
        bias_table_keys = [k for k in state_dict if "relative_position_bias_table" in k]
        for k in bias_table_keys:
            table_pretrained = state_dict[k]
            # pylint: disable=E1136
            table_current = current_state[k]
            l1, nh1 = table_pretrained.size()
            _, nh2 = table_current.size()
            if nh1 != nh2:
                logging.info("Error in loading %s, passing", k)
            elif l1 != l2:
                # the pretrained table is treated as a square s1 x s1 grid
                s1 = int(l1 ** 0.5)
                table_resized = torch.nn.functional.interpolate(
                    table_pretrained.permute(1, 0).view(1, nh1, s1, s1),
                    size=target_hw,
                    mode="bicubic")
                table_pretrained = table_resized.view(nh2, l2).permute(1, 0)
            # tile the 2D table once per temporal offset
            state_dict[k] = table_pretrained.repeat(2 * wd - 1, 1)

        msg = self.load_state_dict(state_dict, strict=False)
        logging.info(msg)
        logging.info("=> loaded successfully %s", self.pretrained)
        del checkpoint
        torch.cuda.empty_cache()

    @staticmethod
    def _strip_checkpoint_prefixes(checkpoint):
        """
        Map checkpoint["state_dict"] onto this model's parameter names.

        Keys starting with "backbone." or "cls_head." lose that prefix; all
        other keys are kept unchanged. Order is preserved.
        """
        new_state_dict = OrderedDict()
        for k, v in checkpoint["state_dict"].items():
            name = k
            for prefix in ("backbone.", "cls_head."):
                if k.startswith(prefix):
                    name = k[len(prefix):]
                    break
            new_state_dict[name] = v
        return new_state_dict

    def load_pretrained_weights(self, pretrained=None, device=None):
        """Initialize the weights from pretrained weights.

        Args:
            pretrained (str, optional): URL of the pre-trained 3D weights.
                Falls back to `self.pretrained` when None. Ignored when
                `self.pretrained2d` is True (the 2D checkpoint is always read
                from `self.pretrained`). Defaults to None.
            device (str, optional): Device the 3D checkpoint is mapped to.
                Falls back to "cpu" when None. Defaults to None.
        """
        if pretrained is None:
            pretrained = self.pretrained
        if device is None:
            device = "cpu"

        logging.info("load model from: %s", self.pretrained)
        if self.pretrained2d:
            # Inflate 2D model into 3D model.
            self.inflate_weights()
        else:
            # Directly load 3D model.
            checkpoint = model_zoo.load_url(pretrained, map_location=torch.device(device))
            new_state_dict = self._strip_checkpoint_prefixes(checkpoint)
            self.load_state_dict(new_state_dict, strict=True)

    def get_patch_embedding(self, x):
        """
        Embed the input into patch tokens.

        Args:
            x (`torch.Tensor`):
                5-dimensional input, B x C x T x H x W. An input with exactly
                4 channels is treated as RGB + depth.

        Returns:
            The output of the patch embedding layer(s).

        Raises:
            NotImplementedError: for a 4-channel input when `depth_mode` is
                neither "summed_rgb_d_tokens" nor "rgbd".
        """
        # x: B x C x T x H x W
        assert x.ndim == 5
        has_depth = x.shape[1] == 4

        if not has_depth:
            return self.patch_embed(x)

        if self.depth_mode in ["summed_rgb_d_tokens"]:
            x_rgb = x[:, :3, ...]
            x_d = x[:, 3:, ...]
            x_d = self.depth_patch_embed(x_d)
            x_rgb = self.patch_embed(x_rgb)
            # sum the two sets of tokens
            return x_rgb + x_d
        if self.depth_mode == "rgbd":
            if self.depth_patch_embed_separate_params:
                return self.depth_patch_embed(x)
            return self.patch_embed(x)
        raise NotImplementedError()

    def forward(self, x):
        """
        Compute the normalized feature map of the last stage.

        Args:
            x (`torch.Tensor`):
                Input video, B x C x T x H x W.

        Returns:
            Feature map in "n c d h w" layout.
        """
        x = self.get_patch_embedding(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x.contiguous())

        # the norm layer works on the last axis, so move channels there and back
        x = rearrange(x, "n c d h w -> n d h w c")
        x = self.norm(x)
        x = rearrange(x, "n d h w c -> n c d h w")
        return x

    def forward_features(self, x):
        """
        Compute pooled features, ready to be passed to `head`.

        Args:
            x (`torch.Tensor`):
                Input video, B x C x T x H x W.

        Returns:
            Features of shape [n, c].
        """
        x = self.forward(x)
        # [n, c, 1, 1, 1]
        x = self.avg_pool3d(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # [n, c]
        x = x.view(x.size(0), -1)
        return x

    def head(self, x):
        """
        Classification head. Call `forward_features` first to get the features.

        Args:
            x (`torch.Tensor`):
                The pooled features. [n, c]

        Returns:
            Classification scores. [n, num_classes]
        """
        # [n, num_classes]
        cls_score = self.fc_cls(x)
        return cls_score
def create_model(model_name: str = None, pretrained: bool = False,
                 device: str = None, **kwargs):
    """
    Build a VideoSwinTransformer, either from a named configuration or from
    explicit keyword arguments.

    Args:
        model_name (`str`):
            Name looked up with `get_configs.configs`. When empty or None the
            model is built from `kwargs` instead. Default: None.
        pretrained (`bool`):
            If True, keep the "pretrained" entry of the named configuration so
            that the weights get loaded. Requires `model_name`. Default: False.
        device (`str`):
            Device forwarded to the model constructor for a named
            configuration. Defaults to "cuda" when available, else "cpu".
        **kwargs:
            Constructor arguments, used only when `model_name` is empty; they
            are ignored for a named configuration.

    Returns:
        The constructed VideoSwinTransformer.

    Raises:
        AssertionError: if `pretrained` is True and no model name is given.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # An empty name cannot select pretrained weights either; reject it instead
    # of silently returning a randomly initialised model.
    if pretrained and not model_name:
        raise AssertionError("Fail to load pretrained model: no model name is specified.")

    if not model_name:
        return VideoSwinTransformer(**kwargs)

    preset = get_configs.configs(model_name)
    forwarded_keys = ("num_classes", "embed_dim", "depths", "num_heads",
                      "patch_size", "window_size", "drop_path_rate", "patch_norm")
    model_configs = {key: preset[key] for key in forwarded_keys}
    model_configs["pretrained"] = preset["pretrained"] if pretrained else None
    model_configs["device"] = device
    return VideoSwinTransformer(**model_configs)