Source code for towhee.models.tsm.temporal_shift

# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified by Zilliz.

import torch
from torch import nn
import torch.nn.functional as F
import torchvision

class TemporalShift(nn.Module):
    """
    Temporal Shift Module: shifts a fraction of the channels along the
    temporal dimension before applying the wrapped layer, enabling temporal
    modeling without extra FLOPs.

    Args:
        net (nn.Module): layer to wrap; it is applied after the shift.
        n_segment (int): number of frames (temporal segments) per clip.
        n_div (int): channels are split into `n_div` folds; one fold is
            shifted forward in time, one backward, and the rest stay in place.
        inplace (bool): use the in-place shift variant (currently disabled).
    """

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super().__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace

    def forward(self, x):
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # Disabled due to an out-of-order error when performing parallel
            # computing; may need a dedicated CUDA kernel.
            raise NotImplementedError
            # out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shifted

        return out.view(nt, c, h, w)
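# A minimal usage sketch (not part of the original module; the `_demo_` name
# is hypothetical): wrapping a 2D conv in TemporalShift. Input is laid out as
# (n_batch * n_segment, c, h, w), so the shift can regroup frames belonging
# to the same clip.
def _demo_temporal_shift():
    conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)
    tsm = TemporalShift(conv, n_segment=8, n_div=8)
    x = torch.randn(2 * 8, 16, 56, 56)  # 2 clips x 8 frames
    y = tsm(x)
    assert y.shape == x.shape  # the shift itself is shape-preserving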
class InplaceShift(torch.autograd.Function):
    """
    In-place version of the temporal shift, implemented as a custom autograd
    function.

    Special thanks to @raoyongming for the help with this function.
    """

    @staticmethod
    def forward(ctx, input_x, fold):
        # Higher-order gradients are not supported.
        ctx.fold_ = fold
        n, t, _, h, w = input_x.size()
        buffer = input_x.data.new(n, t, fold, h, w).zero_()
        # Shift the first `fold` channels one step in time.
        buffer[:, :-1] = input_x.data[:, 1:, :fold]
        input_x.data[:, :, :fold] = buffer
        buffer.zero_()
        # Shift the next `fold` channels one step in the opposite direction.
        buffer[:, 1:] = input_x.data[:, :-1, fold: 2 * fold]
        input_x.data[:, :, fold: 2 * fold] = buffer
        return input_x

    @staticmethod
    def backward(ctx, grad_output):
        fold = ctx.fold_
        n, t, _, h, w = grad_output.size()
        buffer = grad_output.data.new(n, t, fold, h, w).zero_()
        # Gradients flow in the opposite temporal direction of the forward shift.
        buffer[:, 1:] = grad_output.data[:, :-1, :fold]
        grad_output.data[:, :, :fold] = buffer
        buffer.zero_()
        buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
        grad_output.data[:, :, fold: 2 * fold] = buffer
        return grad_output, None
class TemporalPool(nn.Module):
    """
    Halves the temporal resolution with 3D max pooling before applying the
    wrapped layer.

    Args:
        net (nn.Module): layer to wrap; it is applied after the pooling.
        n_segment (int): number of frames (temporal segments) per clip.
    """

    def __init__(self, net, n_segment):
        super().__init__()
        self.net = net
        self.n_segment = n_segment

    def forward(self, x):
        x = self.temporal_pool(x, n_segment=self.n_segment)
        return self.net(x)

    @staticmethod
    def temporal_pool(x, n_segment):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
        return x
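# A minimal sketch (hypothetical `_demo_` helper): temporal_pool halves the
# number of frames per clip, so nt drops from n_batch * n_segment to
# n_batch * n_segment // 2 while the spatial size is untouched.
def _demo_temporal_pool():
    pool = TemporalPool(nn.Identity(), n_segment=8)
    x = torch.randn(4 * 8, 64, 28, 28)  # 4 clips x 8 frames
    y = pool(x)
    assert y.shape == (4 * 4, 64, 28, 28)  # 4 clips x 4 frames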
def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
    if temporal_pool:
        # Later stages run at half the temporal resolution (see TemporalPool).
        n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
    else:
        n_segment_list = [n_segment] * 4
    assert n_segment_list[-1] > 0

    if isinstance(net, torchvision.models.ResNet):
        if place == 'block':
            # Wrap every residual block as a whole.
            def make_block_temporal(stage, this_segment):
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*blocks)

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
        elif 'blockres' in place:
            n_round = 1
            if len(list(net.layer3.children())) >= 23:
                # Very deep backbones (e.g. ResNet-101): shift every other block.
                n_round = 2

            # Wrap only the first conv inside each selected residual block,
            # so the shift happens on the residual branch.
            def make_block_temporal(stage, this_segment):
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    if i % n_round == 0:
                        blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*blocks)

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
    else:
        raise NotImplementedError(place)
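# A minimal sketch of retrofitting a plain torchvision ResNet with TSM
# (hypothetical `_demo_` helper; batch and input sizes are arbitrary). With
# place='blockres', each residual block's first conv is wrapped so the shift
# runs inside the residual branch.
def _demo_make_temporal_shift():
    resnet = torchvision.models.resnet50()
    make_temporal_shift(resnet, n_segment=8, n_div=8, place='blockres')
    x = torch.randn(2 * 8, 3, 224, 224)  # 2 clips x 8 frames
    logits = resnet(x)  # per-frame logits: (16, 1000)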
def make_temporal_pool(net, n_segment):
    if isinstance(net, torchvision.models.ResNet):
        net.layer2 = TemporalPool(net.layer2, n_segment)
    else:
        raise NotImplementedError
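# A minimal sketch combining both helpers (hypothetical `_demo_` function):
# with temporal_pool=True the later stages expect half the frame rate, so
# make_temporal_pool must also wrap layer2 to actually drop the frames.
def _demo_make_temporal_pool():
    resnet = torchvision.models.resnet50()
    make_temporal_shift(resnet, n_segment=8, n_div=8, temporal_pool=True)
    make_temporal_pool(resnet, n_segment=8)
    x = torch.randn(1 * 8, 3, 224, 224)  # 1 clip x 8 frames
    logits = resnet(x)  # 4 frames survive the pooling: (4, 1000)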