# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified by Zilliz.
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
class TemporalShift(nn.Module):
    """Wrap a 2D module with the Temporal Shift Module (TSM).

    A fraction of the input channels is shifted along the temporal
    dimension before the wrapped module runs, giving a 2D CNN temporal
    modeling ability at zero extra FLOPs (arXiv:1811.08383).

    Args:
        net (nn.Module): module applied after the temporal shift.
        n_segment (int): number of frames (temporal segments) per clip.
        n_div (int): reciprocal of the fraction of channels to shift;
            ``c // n_div`` channels are shifted in each direction.
        inplace (bool): if True, request the in-place shift variant
            (currently raises NotImplementedError in :meth:`shift`).
    """

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super().__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace

    def forward(self, x):
        # Shift channels across time, then delegate to the wrapped module.
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        """Shift ``c // fold_div`` channels one step backward/forward in time.

        Args:
            x (Tensor): input of shape ``(n * n_segment, c, h, w)``.
            n_segment (int): temporal length of each clip.
            fold_div (int): channel-fraction divisor.
            inplace (bool): in-place variant; not implemented.

        Returns:
            Tensor: shifted tensor of the same shape as ``x``.

        Raises:
            NotImplementedError: if ``inplace`` is True.
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # Due to an out-of-order error when performing parallel
            # computing; may need a dedicated CUDA kernel.
            raise NotImplementedError
            # out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left: frame t gets t+1
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # remaining channels: no shift

        return out.view(nt, c, h, w)
class InplaceShift(torch.autograd.Function):
    """In-place temporal shift with a hand-written backward pass.

    Special thanks to @raoyongming for the help with this function.
    """

    @staticmethod
    def forward(ctx, input_x, fold):
        # NOTE: higher-order gradients are not supported.
        ctx.fold_ = fold
        n, t, _, h, w = input_x.size()
        buffer = input_x.data.new(n, t, fold, h, w).zero_()
        # First `fold` channels: shift left (frame t receives frame t+1,
        # last frame zero-filled), written back into input_x in place.
        buffer[:, :-1] = input_x.data[:, 1:, :fold]
        input_x.data[:, :, :fold] = buffer
        buffer.zero_()
        # Next `fold` channels: shift right (frame t receives frame t-1).
        buffer[:, 1:] = input_x.data[:, :-1, fold: 2 * fold]
        input_x.data[:, :, fold: 2 * fold] = buffer
        return input_x

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a temporal shift is the opposite shift,
        # applied to grad_output in place.
        fold = ctx.fold_
        n, t, _, h, w = grad_output.size()
        buffer = grad_output.data.new(n, t, fold, h, w).zero_()
        buffer[:, 1:] = grad_output.data[:, :-1, :fold]
        grad_output.data[:, :, :fold] = buffer
        buffer.zero_()
        buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
        grad_output.data[:, :, fold: 2 * fold] = buffer
        # Second return value is the (non-differentiable) `fold` argument.
        return grad_output, None
class TemporalPool(nn.Module):
    """Halve the temporal resolution with max pooling, then apply ``net``.

    Args:
        net (nn.Module): module applied after the temporal pooling.
        n_segment (int): number of frames per clip in the input.
            NOTE(review): the final reshape assumes the pooled temporal
            length equals ``n_segment // 2``, which only holds for even
            ``n_segment`` — confirm callers respect this.
    """

    def __init__(self, net, n_segment):
        super().__init__()
        self.net = net
        self.n_segment = n_segment

    def forward(self, x):
        x = self.temporal_pool(x, n_segment=self.n_segment)
        return self.net(x)

    @staticmethod
    def temporal_pool(x, n_segment):
        """Max-pool ``x`` of shape ``(n * n_segment, c, h, w)`` over time.

        Returns:
            Tensor of shape ``(n * n_segment // 2, c, h, w)``.
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        # kernel 3 / stride 2 / padding 1 on the temporal axis only.
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
        return x
def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
    """Insert :class:`TemporalShift` modules into a torchvision ResNet.

    Args:
        net: backbone network; only ``torchvision.models.ResNet`` is handled.
        n_segment (int): number of frames per clip.
        n_div (int): channel-fraction divisor passed to TemporalShift.
        place (str): ``'block'`` wraps each whole residual block; any string
            containing ``'blockres'`` wraps only each block's ``conv1``,
            i.e. the shift happens inside the residual branch.
        temporal_pool (bool): if True, stages 2-4 are configured for half the
            temporal resolution (to be used with :func:`make_temporal_pool`).

    Raises:
        NotImplementedError: if ``place`` is not recognized (ResNet only).

    NOTE(review): non-ResNet backbones are silently left unmodified —
    confirm that is the intended behavior for callers.
    """
    if temporal_pool:
        # After temporal pooling at layer2, the later stages see n_segment // 2 frames.
        n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
    else:
        n_segment_list = [n_segment] * 4
    assert n_segment_list[-1] > 0

    if isinstance(net, torchvision.models.ResNet):
        if place == 'block':

            def make_block_temporal(stage, this_segment):
                # Wrap every residual block of the stage in a TemporalShift.
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*blocks)

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
        elif 'blockres' in place:
            # Very deep backbones (layer3 with >= 23 blocks, e.g. ResNet-101)
            # only get a shift in every other block.
            n_round = 1
            if len(list(net.layer3.children())) >= 23:
                n_round = 2

            def make_block_temporal(stage, this_segment):
                # Wrap conv1 of every n_round-th residual block.
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    if i % n_round == 0:
                        blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*blocks)

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
        else:
            raise NotImplementedError(place)
def make_temporal_pool(net, n_segment):
    """Wrap ``net.layer2`` of a torchvision ResNet with :class:`TemporalPool`.

    Args:
        net: backbone network; only ``torchvision.models.ResNet`` is handled.
        n_segment (int): number of frames per clip entering layer2.

    Raises:
        NotImplementedError: if ``net`` is not a torchvision ResNet.
    """
    if isinstance(net, torchvision.models.ResNet):
        net.layer2 = TemporalPool(net.layer2, n_segment)
    else:
        raise NotImplementedError