# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified by Zilliz.
import torch
from torch import nn
from torch.nn.init import normal_, constant_
import numpy as np
import torchvision
from towhee.models.utils.basic_ops import ConsensusModule
from towhee.models.layers.non_local import make_non_local
from towhee.models.tsm.mobilenet_v2 import mobilenet_v2, InvertedResidual
from towhee.models.tsm.temporal_shift import TemporalShift, make_temporal_shift
from towhee.models.tsm.config import _C
class TSN(nn.Module):
    """
    Temporal Segment Network with an optional Temporal Shift Module (TSM).

    Frames of every clip are folded into the batch dimension, pushed through a
    2D CNN backbone, classified per frame and finally fused over the segments
    by ``ConsensusModule``.

    Args:
        num_class (int): number of output classes.
        num_segments (int): number of segments sampled per clip.
        modality (str): 'RGB', 'Flow' or 'RGBDiff'. Non-RGB modalities get a
            first convolution with a different number of input channels and
            different normalization statistics.
        base_model (str): a torchvision ResNet name (e.g. 'resnet50') or
            'mobilenetv2'.
        new_length (int): frames stacked per segment; defaults to 1 for 'RGB'
            and 5 for the other modalities.
        consensus_type (str): consensus passed to ``ConsensusModule``.
        before_softmax (bool): if False, a softmax over classes is applied to
            the per-frame scores before the consensus (requires 'avg').
        dropout (float): dropout before the new classifier. With 0 the
            backbone's last layer is replaced by the classifier itself.
        img_feature_dim (int): stored on the instance; unused in this class.
        crop_num (int): stored on the instance; unused in this class.
        partial_bn (bool): stored as ``_enable_pbn`` (see ``partialbn``).
        pretrain (str): 'imagenet' requests pretrained backbone weights.
        is_shift (bool): insert temporal shift modules into the backbone.
        shift_div (int): forwarded as ``n_div`` to the temporal shift helpers.
        shift_place (str): forwarded as ``place`` to ``make_temporal_shift``
            (ResNet only).
        fc_lr5 (bool): stored on the instance; unused in this class.
        temporal_pool (bool): forwarded to ``make_temporal_shift``. Combined
            with ``is_shift``, the head expects ``num_segments // 2`` segments.
        non_local (bool): insert non-local blocks (ResNet only).

    Raises:
        ValueError: if a softmax before a non-'avg' consensus is requested, or
            if ``base_model`` is unknown.
    """

    def __init__(self,
                 num_class,
                 num_segments,
                 modality,
                 base_model='resnet101',
                 new_length=None,
                 consensus_type='avg',
                 before_softmax=True,
                 dropout=0.8,
                 img_feature_dim=256,
                 crop_num=1,
                 partial_bn=True,
                 pretrain='imagenet',
                 is_shift=False,
                 shift_div=8,
                 shift_place='blockres',
                 fc_lr5=False,
                 temporal_pool=False,
                 non_local=False):
        super().__init__()
        self.modality = modality
        self.num_segments = num_segments
        self.reshape = True
        self.before_softmax = before_softmax
        self.dropout = dropout
        self.crop_num = crop_num
        self.consensus_type = consensus_type
        self.img_feature_dim = img_feature_dim  # the dimension of the CNN feature to represent each frame
        self.pretrain = pretrain
        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.base_model_name = base_model
        self.fc_lr5 = fc_lr5
        self.temporal_pool = temporal_pool
        self.non_local = non_local

        if not before_softmax and consensus_type != 'avg':
            raise ValueError('Only avg consensus can be used after Softmax')

        if new_length is None:
            self.new_length = 1 if modality == 'RGB' else 5
        else:
            self.new_length = new_length

        self._prepare_base_model(base_model)
        self._prepare_tsn(num_class)

        # Non-RGB inputs need a first convolution with a different channel count.
        if self.modality == 'Flow':
            self.base_model = self._construct_flow_model(self.base_model)
        elif self.modality == 'RGBDiff':
            self.base_model = self._construct_diff_model(self.base_model)

        self.consensus = ConsensusModule(consensus_type)

        if not self.before_softmax:
            # Per-frame scores are (batch * segments, num_class): normalize over classes.
            self.softmax = nn.Softmax(dim=1)

        self._enable_pbn = partial_bn
        if partial_bn:
            self.partialbn(True)

    def _prepare_tsn(self, num_class):
        """Swap the backbone's last layer for dropout/classifier and initialize it."""
        last_layer_name = self.base_model.last_layer_name
        feature_dim = getattr(self.base_model, last_layer_name).in_features
        if self.dropout == 0:
            # No dropout: the backbone's own last layer becomes the classifier.
            classifier = nn.Linear(feature_dim, num_class)
            setattr(self.base_model, last_layer_name, classifier)
            self.new_fc = None
        else:
            # The last layer becomes dropout; ``head`` applies ``new_fc`` afterwards.
            setattr(self.base_model, last_layer_name, nn.Dropout(p=self.dropout))
            self.new_fc = nn.Linear(feature_dim, num_class)
            classifier = self.new_fc
        std = 0.001
        normal_(classifier.weight, 0, std)
        constant_(classifier.bias, 0)

    def _prepare_base_model(self, base_model):
        """Build the 2D backbone and set the input normalization statistics.

        Raises:
            ValueError: if ``base_model`` is neither a torchvision ResNet name
                nor 'mobilenetv2'.
        """
        pretrained = bool(self.pretrain == 'imagenet')
        if 'resnet' in base_model:
            self.base_model = getattr(torchvision.models, base_model)(pretrained)
            if self.is_shift:
                make_temporal_shift(self.base_model, self.num_segments,
                                    n_div=self.shift_div, place=self.shift_place,
                                    temporal_pool=self.temporal_pool)
            if self.non_local:
                make_non_local(self.base_model, self.num_segments)
            self.base_model.last_layer_name = 'fc'
        elif base_model == 'mobilenetv2':
            self.base_model = mobilenet_v2(pretrained)
            self.base_model.last_layer_name = 'classifier'
            if self.is_shift:
                for m in self.base_model.modules():
                    # Only residual InvertedResidual blocks whose ``conv`` has 8 layers
                    # are shifted (presumably the expansion variant -- confirm in mobilenet_v2).
                    if isinstance(m, InvertedResidual) and len(m.conv) == 8 and m.use_res_connect:
                        m.conv[0] = TemporalShift(m.conv[0], n_segment=self.num_segments, n_div=self.shift_div)
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))

        self.input_size = 224
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)

        if self.modality == 'Flow':
            self.input_mean = [0.5]
            self.input_std = [np.mean(self.input_std)]
        elif self.modality == 'RGBDiff':
            self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
            self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length

    def partialbn(self, enable):
        """Record whether partial BatchNorm is enabled.

        Only the ``_enable_pbn`` flag is stored; nothing in this class reads it.
        NOTE(review): upstream TSM freezes BN layers in ``train()`` based on this
        flag -- confirm whether that behavior is needed here.
        """
        self._enable_pbn = enable

    def head(self, base_out):
        """Classify per-frame features and fuse them over the segments.

        Args:
            base_out: backbone output with one row per frame (segments folded
                into the batch dimension).

        Returns:
            The consensus output with dimension 1 squeezed (presumably the
            segment dimension kept by ``ConsensusModule`` -- confirm).
        """
        if self.new_fc is not None:
            base_out = self.new_fc(base_out)
        if not self.before_softmax:
            base_out = self.softmax(base_out)
        if self.reshape:
            # With temporal pooling inside the shifted backbone, half the segments remain.
            if self.is_shift and self.temporal_pool:
                n_seg = self.num_segments // 2
            else:
                n_seg = self.num_segments
            base_out = base_out.view((-1, n_seg) + base_out.size()[1:])
        output = self.consensus(base_out)
        return output.squeeze(1)

    def forward(self, input_x, no_reshape=False):
        """Run the backbone followed by the classification head."""
        return self.head(self.forward_features(input_x, no_reshape))

    def forward_features(self, input_x, no_reshape=False):
        """Compute the backbone output for every frame.

        Unless ``no_reshape`` is set, the input is viewed as
        ``(-1, sample_len, H, W)``, where ``sample_len`` is 3 (RGB / RGBDiff)
        or 2 (other modalities) times ``new_length``. RGBDiff input is first
        converted to frame differences.
        """
        if no_reshape:
            return self.base_model(input_x)
        sample_len = (3 if self.modality in ('RGB', 'RGBDiff') else 2) * self.new_length
        if self.modality == 'RGBDiff':
            input_x = self._get_diff(input_x)
        return self.base_model(input_x.view((-1, sample_len) + input_x.size()[-2:]))

    def _get_diff(self, input_x, keep_rgb=False):
        """Compute differences between consecutive frames of each segment.

        ``input_x`` is viewed as ``(-1, num_segments, new_length + 1, C, H, W)``,
        where C is 3 for RGB / RGBDiff and 2 otherwise. The result holds
        ``frame[i + 1] - frame[i]`` for the ``new_length`` consecutive pairs.
        With ``keep_rgb`` the first raw frame is prepended, giving
        ``new_length + 1`` entries along dim 2.
        """
        input_c = 3 if self.modality in ['RGB', 'RGBDiff'] else 2
        input_view = input_x.view((-1, self.num_segments, self.new_length + 1, input_c) + input_x.size()[2:])
        diffs = input_view[:, :, 1:] - input_view[:, :, :-1]
        if keep_rgb:
            return torch.cat((input_view[:, :, :1], diffs), dim=2)
        return diffs

    @staticmethod
    def _replace_first_conv(base_model, make_weight):
        """Replace the first ``nn.Conv2d`` of ``base_model`` in place.

        ``make_weight`` maps the old kernel (detached) to the new one. The new
        layer's input channel count is taken from that kernel. Stride, padding,
        dilation and bias presence are copied; the bias values are copied too.

        Returns:
            ``base_model`` (modified in place).

        Raises:
            ValueError: if ``base_model`` contains no ``nn.Conv2d``.
        """
        first = next(((n, m) for n, m in base_model.named_modules() if isinstance(m, nn.Conv2d)), None)
        if first is None:
            raise ValueError('base_model contains no nn.Conv2d layer to adapt')
        name, conv_layer = first
        # Resolve the direct parent from the dotted module path (e.g. 'features.0.0').
        parent_name, _, attr_name = name.rpartition('.')
        parent = base_model
        if parent_name:
            for part in parent_name.split('.'):
                parent = getattr(parent, part)

        new_kernels = make_weight(conv_layer.weight.detach())
        has_bias = conv_layer.bias is not None
        new_conv = nn.Conv2d(new_kernels.size(1), conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             dilation=conv_layer.dilation, bias=has_bias)
        new_conv.weight.data = new_kernels
        if has_bias:
            new_conv.bias.data = conv_layer.bias.detach().clone()
        setattr(parent, attr_name, new_conv)
        return base_model

    def _construct_flow_model(self, base_model):
        """Adapt the first convolution to ``2 * new_length`` optical-flow channels.

        The new kernel is the original kernel averaged over its input channels
        and replicated for every flow channel.
        """
        def flow_kernels(weight):
            size = weight.size()
            return weight.mean(dim=1, keepdim=True).expand(
                size[:1] + (2 * self.new_length,) + size[2:]).contiguous()

        return self._replace_first_conv(base_model, flow_kernels)

    def _construct_diff_model(self, base_model, keep_rgb=False):
        """Adapt the first convolution to RGB-difference input.

        The new kernel has ``3 * new_length`` channels built from the
        channel-averaged original kernel. With ``keep_rgb`` the original RGB
        kernel is prepended, giving ``3 + 3 * new_length`` channels.
        """
        def diff_kernels(weight):
            size = weight.size()
            averaged = weight.mean(dim=1, keepdim=True).expand(
                size[:1] + (3 * self.new_length,) + size[2:]).contiguous()
            if keep_rgb:
                return torch.cat((weight, averaged), 1)
            return averaged

        return self._replace_first_conv(base_model, diff_kernels)
def create_model(
        model_name: str = 'tsm_k400_r50_seg8',
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None,
):
    """Build a TSM model from one of the predefined configurations.

    Args:
        model_name: one of 'tsm_k400_r50_seg8', 'tsm_k400_r50_seg16',
            'tsm_some_r50_seg8', 'tsm_somev2_r50_seg16', 'tsm_flow_k400_r50_seg8'.
        pretrained: if True, load weights from ``weights_path``.
        weights_path: checkpoint file containing a ``'state_dict'`` entry;
            required when ``pretrained`` is True.
        device: target device; defaults to 'cuda' when available, else 'cpu'.

    Returns:
        The ``TSN`` model moved to ``device``.

    Raises:
        ValueError: if ``pretrained`` is True but ``weights_path`` is None.
        AttributeError: if ``model_name`` is not a known configuration.
    """
    if pretrained and weights_path is None:
        raise ValueError('weights_path must be provided when pretrained=True.')
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # model_name -> attribute of ``_C.MODEL`` holding its configuration
    config_names = {
        'tsm_k400_r50_seg8': 'TSMK400R50',
        'tsm_k400_r50_seg16': 'TSMK400R50S16',
        'tsm_some_r50_seg8': 'TSMSOMER50',
        'tsm_somev2_r50_seg16': 'TSMSOMEV2R50SEG16',
        'tsm_flow_k400_r50_seg8': 'TSMFlowK400R50',
    }
    if model_name not in config_names:
        raise AttributeError(f'Invalid model_name {model_name}.')
    model_config = getattr(_C.MODEL, config_names[model_name])

    model = TSN(
        num_class=model_config.num_class,
        num_segments=model_config.num_segments,
        new_length=model_config.new_length,
        modality=model_config.input_modality,
        base_model=model_config.base_model,
        consensus_type=model_config.consensus_module,
        img_feature_dim=model_config.img_feature_dim,
        pretrain=model_config.pretrain,
        is_shift=model_config.is_shift,
        shift_div=model_config.shift_div,
        shift_place=model_config.shift_place,
        non_local=model_config.non_local,
        dropout=model_config.dropout_ratio,
    )

    if pretrained:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location=device)
        state_dict = checkpoint['state_dict']
        # Drop the first dotted component of every key (presumably a 'module.'
        # prefix from DataParallel -- confirm against the checkpoints used).
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}
        # Some checkpoints keep the classifier under the backbone; map it to ``new_fc``.
        for old_key, new_key in (('base_model.classifier.weight', 'new_fc.weight'),
                                 ('base_model.classifier.bias', 'new_fc.bias')):
            if old_key in base_dict:
                base_dict[new_key] = base_dict.pop(old_key)
        model.load_state_dict(base_dict)
    model.to(device)
    return model