
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified by Zilliz.

import torch
from torch import nn
from torch.nn.init import normal_, constant_

import numpy as np
import torchvision

from towhee.models.utils.basic_ops import ConsensusModule
from towhee.models.layers.non_local import make_non_local
from towhee.models.tsm.mobilenet_v2 import mobilenet_v2, InvertedResidual
from towhee.models.tsm.temporal_shift import TemporalShift, make_temporal_shift
from towhee.models.tsm.config import _C

class TSN(nn.Module):
    """
    Temporal Segment Network (TSN), optionally equipped with the Temporal Shift
    Module (TSM) for efficient video understanding.

    Args:
        num_class (`int`):
            Number of classes for the final classifier.
        num_segments (`int`):
            Number of temporal segments sampled from each video.
        modality (`str`):
            Input modality, one of 'RGB', 'Flow' or 'RGBDiff'.
        base_model (`str`):
            Backbone name, a torchvision ResNet variant or 'mobilenetv2'.
        new_length (`int`):
            Number of consecutive frames per segment; defaults to 1 for RGB
            and 5 for the other modalities.
        consensus_type (`str`):
            Segmental consensus function, e.g. 'avg'.
        before_softmax (`bool`):
            If False, a softmax is applied before consensus (only 'avg'
            consensus is allowed in that case).
        dropout (`float`):
            Dropout ratio applied before the new classification layer.
        img_feature_dim (`int`):
            Dimension of the CNN feature used to represent each frame.
        crop_num (`int`):
            Number of spatial crops per frame.
        partial_bn (`bool`):
            Whether to enable partial batch normalization.
        pretrain (`str`):
            'imagenet' to initialize the backbone with ImageNet weights.
        is_shift (`bool`):
            Whether to insert temporal shift modules into the backbone.
        shift_div (`int`):
            Fraction of channels to shift (1 / shift_div).
        shift_place (`str`):
            Where to insert the shift, e.g. 'blockres'.
        fc_lr5 (`bool`):
            Whether the classifier uses a 5x learning rate during fine-tuning.
        temporal_pool (`bool`):
            Whether temporal pooling halves the number of segments.
        non_local (`bool`):
            Whether to add non-local blocks to the backbone.
    """
    def __init__(self,
                 num_class,
                 num_segments,
                 modality,
                 base_model='resnet101',
                 new_length=None,
                 consensus_type='avg',
                 before_softmax=True,
                 dropout=0.8,
                 img_feature_dim=256,
                 crop_num=1,
                 partial_bn=True,
                 pretrain='imagenet',
                 is_shift=False,
                 shift_div=8,
                 shift_place='blockres',
                 fc_lr5=False,
                 temporal_pool=False,
                 non_local=False):
        super().__init__()
        self.modality = modality
        self.num_segments = num_segments
        self.reshape = True
        self.before_softmax = before_softmax
        self.dropout = dropout
        self.crop_num = crop_num
        self.consensus_type = consensus_type
        self.img_feature_dim = img_feature_dim  # the dimension of the CNN feature to represent each frame
        self.pretrain = pretrain
        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.base_model_name = base_model
        self.fc_lr5 = fc_lr5
        self.temporal_pool = temporal_pool
        self.non_local = non_local

        if not before_softmax and consensus_type != 'avg':
            raise ValueError('Only avg consensus can be used after Softmax')
        if new_length is None:
            self.new_length = 1 if modality == 'RGB' else 5
        else:
            self.new_length = new_length

        self._prepare_base_model(base_model)
        self._prepare_tsn(num_class)

        if self.modality == 'Flow':
            self.base_model = self._construct_flow_model(self.base_model)
        elif self.modality == 'RGBDiff':
            self.base_model = self._construct_diff_model(self.base_model)

        self.consensus = ConsensusModule(consensus_type)
        if not self.before_softmax:
            self.softmax = nn.Softmax()
        self._enable_pbn = partial_bn
        if partial_bn:
            self.partialbn(True)
    def _prepare_tsn(self, num_class):
        feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
        if self.dropout == 0:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
            self.new_fc = None
        else:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
            self.new_fc = nn.Linear(feature_dim, num_class)

        std = 0.001
        if self.new_fc is None:
            normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
            constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
        else:
            if hasattr(self.new_fc, 'weight'):
                normal_(self.new_fc.weight, 0, std)
                constant_(self.new_fc.bias, 0)

    def _prepare_base_model(self, base_model):
        if 'resnet' in base_model:
            self.base_model = getattr(torchvision.models, base_model)(bool(self.pretrain == 'imagenet'))
            if self.is_shift:
                make_temporal_shift(self.base_model, self.num_segments,
                                    n_div=self.shift_div, place=self.shift_place,
                                    temporal_pool=self.temporal_pool)
            if self.non_local:
                make_non_local(self.base_model, self.num_segments)
            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
            self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)
            if self.modality == 'Flow':
                self.input_mean = [0.5]
                self.input_std = [np.mean(self.input_std)]
            elif self.modality == 'RGBDiff':
                self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
                self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
        elif base_model == 'mobilenetv2':
            self.base_model = mobilenet_v2(bool(self.pretrain == 'imagenet'))
            self.base_model.last_layer_name = 'classifier'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
            self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)
            if self.is_shift:
                for m in self.base_model.modules():
                    if isinstance(m, InvertedResidual) and len(m.conv) == 8 and m.use_res_connect:
                        m.conv[0] = TemporalShift(m.conv[0], n_segment=self.num_segments, n_div=self.shift_div)
            if self.modality == 'Flow':
                self.input_mean = [0.5]
                self.input_std = [np.mean(self.input_std)]
            elif self.modality == 'RGBDiff':
                self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
                self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))

    def partialbn(self, enable):
        self._enable_pbn = enable

    def head(self, base_out):
        if self.dropout > 0:
            base_out = self.new_fc(base_out)
        if not self.before_softmax:
            base_out = self.softmax(base_out)
        if self.reshape:
            if self.is_shift and self.temporal_pool:
                base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
            else:
                base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
            output = self.consensus(base_out)
            return output.squeeze(1)
    def forward(self, input_x, no_reshape=False):
        base_out = self.forward_features(input_x, no_reshape)
        return self.head(base_out)
    def forward_features(self, input_x, no_reshape=False):
        if not no_reshape:
            sample_len = (3 if self.modality == 'RGB' else 2) * self.new_length
            if self.modality == 'RGBDiff':
                sample_len = 3 * self.new_length
                input_x = self._get_diff(input_x)
            base_out = self.base_model(input_x.view((-1, sample_len) + input_x.size()[-2:]))
        else:
            base_out = self.base_model(input_x)
        return base_out

    def _get_diff(self, input_x, keep_rgb=False):
        input_c = 3 if self.modality in ['RGB', 'RGBDiff'] else 2
        input_view = input_x.view((-1, self.num_segments, self.new_length + 1, input_c,) + input_x.size()[2:])
        if keep_rgb:
            new_data = input_view.clone()
        else:
            new_data = input_view[:, :, 1:, :, :, :].clone()
        for x in reversed(list(range(1, self.new_length + 1))):
            if keep_rgb:
                new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
            else:
                new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
        return new_data

    def _construct_flow_model(self, base_model):
        # modify the convolution layers
        # Torch models are usually defined in a hierarchical way.
        # nn.modules.children() returns all sub-modules in a DFS manner
        modules = list(self.base_model.modules())
        first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
        conv_layer = modules[first_conv_idx]
        container = modules[first_conv_idx - 1]

        # modify parameters, assume the first blob contains the convolution kernels
        params = [x.clone() for x in conv_layer.parameters()]
        kernel_size = params[0].size()
        new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
        new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()

        new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             bias=bool(len(params) == 2))
        new_conv.weight.data = new_kernels
        if len(params) == 2:
            new_conv.bias.data = params[1].data  # add bias if necessary
        layer_name = list(container.state_dict().keys())[0][:-7]  # remove .weight suffix to get the layer name

        # replace the first convolution layer
        setattr(container, layer_name, new_conv)
        return base_model

    def _construct_diff_model(self, base_model, keep_rgb=False):
        # modify the convolution layers
        # Torch models are usually defined in a hierarchical way.
        # nn.modules.children() returns all sub-modules in a DFS manner
        modules = list(self.base_model.modules())
        first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
        conv_layer = modules[first_conv_idx]
        container = modules[first_conv_idx - 1]

        # modify parameters, assume the first blob contains the convolution kernels
        params = [x.clone() for x in conv_layer.parameters()]
        kernel_size = params[0].size()
        if not keep_rgb:
            new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
            new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
        else:
            new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
            new_kernels = torch.cat((params[0].data,
                                     params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
                                    1)
            new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]

        new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             bias=bool(len(params) == 2))
        new_conv.weight.data = new_kernels
        if len(params) == 2:
            new_conv.bias.data = params[1].data  # add bias if necessary
        layer_name = list(container.state_dict().keys())[0][:-7]  # remove .weight suffix to get the layer name

        # replace the first convolution layer
        setattr(container, layer_name, new_conv)
        return base_model
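
# Usage sketch (an illustration, not part of the original module): a TSN with
# a ResNet-50 backbone, 8 segments, and the temporal shift enabled expects its
# input flattened to (batch * num_segments, channels, height, width); head()
# folds the segment axis back in and averages scores with the consensus
# module, so the output is (batch, num_class):
#
#     model = TSN(num_class=400, num_segments=8, modality='RGB',
#                 base_model='resnet50', is_shift=True, shift_div=8,
#                 shift_place='blockres', pretrain=None)
#     clip = torch.randn(2 * 8, 3, 224, 224)  # 2 clips x 8 RGB segments
#     scores = model(clip)                    # -> torch.Size([2, 400])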
def create_model(
        model_name: str = 'tsm_k400_r50_seg8',
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None,
):
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if model_name == 'tsm_k400_r50_seg8':
        model_config = _C.MODEL.TSMK400R50
    elif model_name == 'tsm_k400_r50_seg16':
        model_config = _C.MODEL.TSMK400R50S16
    elif model_name == 'tsm_some_r50_seg8':
        model_config = _C.MODEL.TSMSOMER50
    elif model_name == 'tsm_somev2_r50_seg16':
        model_config = _C.MODEL.TSMSOMEV2R50SEG16
    elif model_name == 'tsm_flow_k400_r50_seg8':
        model_config = _C.MODEL.TSMFlowK400R50
    else:
        raise AttributeError(f'Invalid model_name {model_name}.')

    model = TSN(
        num_class=model_config.num_class,
        num_segments=model_config.num_segments,
        new_length=model_config.new_length,
        modality=model_config.input_modality,
        base_model=model_config.base_model,
        consensus_type=model_config.consensus_module,
        img_feature_dim=model_config.img_feature_dim,
        pretrain=model_config.pretrain,
        is_shift=model_config.is_shift,
        shift_div=model_config.shift_div,
        shift_place=model_config.shift_place,
        non_local=model_config.non_local,
        dropout=model_config.dropout_ratio,
    )

    if pretrained:
        checkpoint = torch.load(weights_path, map_location=device)
        checkpoint = checkpoint['state_dict']
        # drop the first name component of each key (e.g. the 'module.' prefix
        # left by DataParallel training)
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias'}
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)
        model.load_state_dict(base_dict)
    model.to(device)
    return model
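
# A minimal smoke test (a sketch, not part of the original module): build the
# randomly initialized 8-segment variant and push a dummy clip through it.
# With pretrained=False no weights_path is needed; the expected output shape
# assumes the TSMK400R50 config uses 8 segments and 400 (Kinetics-400) classes.
if __name__ == '__main__':
    tsm = create_model(model_name='tsm_k400_r50_seg8', pretrained=False, device='cpu')
    tsm.eval()
    dummy = torch.randn(1 * 8, 3, 224, 224)  # one clip of 8 RGB segments
    with torch.no_grad():
        out = tsm(dummy)
    print(out.shape)  # expected: torch.Size([1, 400])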