Source code for towhee.models.transrac.transrac

# Implementation of TransRAC in paper:
#   [TransRAC: Encoding Multi-scale Temporal Correlation with Transformers for Repetitive Action Counting]
#   (https://arxiv.org/abs/2204.01018)
#
# Inspired by official code from https://github.com/SvipRepetitionCounting/TransRAC
#
# Modifications & additions Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from towhee.models.layers.transformer_encoder import TransformerEncoder
from towhee.models.transrac.utils import DenseMap, SimilarityMatrix


class TransRAC(nn.Module):
    """
    TransRAC model

    Args:
        - backbone (`nn.Module`): pretrained backbone model
        - num_frames (`int`): number of video frames
        - scales (`list`): a list of scale numbers

    Example:
        >>> import torch
        >>> from towhee.models import video_swin_transformer
        >>> from towhee.models.transrac import TransRAC
        >>>
        >>> dummy_video = torch.rand(1, 3, 4, 200, 200)  # (bcthw)
        >>>
        >>> backbone = video_swin_transformer.create_model()  # use default configs here
        >>> model = TransRAC(backbone=backbone, num_frames=4)
        >>>
        >>> out, out_matrix = model(dummy_video)
        >>> print(out.shape, out_matrix.shape)
        torch.Size([1, 4]) torch.Size([1, 12, 4, 4])
    """
    def __init__(self, backbone: nn.Module, num_frames: int = 64, scales: list = None):
        super().__init__()
        if scales is None:
            scales = [1, 4, 8]
        self.num_frames = num_frames
        self.scales = scales

        # Pretrained backbone, e.g. a Video Swin Transformer
        self.backbone = backbone

        self.replication_padding1 = nn.ConstantPad3d((0, 0, 0, 0, 1, 1), 0)
        self.replication_padding2 = nn.ConstantPad3d((0, 0, 0, 0, 2, 2), 0)
        self.replication_padding4 = nn.ConstantPad3d((0, 0, 0, 0, 4, 4), 0)
        self.conv3d = nn.Conv3d(in_channels=768, out_channels=512,
                                kernel_size=3, padding=(3, 1, 1), dilation=(3, 1, 1))
        self.bn1 = nn.BatchNorm3d(512)
        self.spatial_pooling = nn.MaxPool3d(kernel_size=(1, 7, 7))

        self.sims = SimilarityMatrix()
        self.conv3x3 = nn.Conv2d(in_channels=4 * len(self.scales),  # num_head * scale_num
                                 out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.dropout1 = nn.Dropout(0.25)

        self.input_projection = nn.Linear(self.num_frames * 32, 512)
        self.ln1 = nn.LayerNorm(512)

        self.trans_encoder = TransformerEncoder(d_model=512, n_head=4, dim_ff=512, num_layers=1,
                                                num_frames=self.num_frames, dropout=0.2)

        self.fc = DenseMap(512, 512, 256, 1)
    def forward(self, x):
        # x: (b, c, t, h, w)
        _, _, t, _, _ = x.shape
        assert t == self.num_frames

        multi_scales = []
        for scale in self.scales:
            # Pad the clip temporally, then slide a window of `scale` frames
            # with stride `scale // 2` to build overlapping temporal crops.
            if scale == 4:
                x = self.replication_padding2(x)
                crops = [x[:, :, i:i + scale, :, :]
                         for i in range(0, self.num_frames - scale + scale // 2 * 2, max(scale // 2, 1))]
            elif scale == 8:
                x = self.replication_padding4(x)
                crops = [x[:, :, i:i + scale, :, :]
                         for i in range(0, self.num_frames - scale + scale // 2 * 2, max(scale // 2, 1))]
            else:
                crops = [x[:, :, i:i + 1, :, :] for i in range(0, self.num_frames)]

            # Feature extraction with the video backbone (e.g. Video Swin Transformer)
            slices = []
            for crop in crops:
                crop = self.backbone(crop)
                slices.append(crop)

            x_scale = torch.cat(slices, dim=2)  # -> (b, 768, f, size, size)
            x_scale = nn.functional.relu(self.bn1(self.conv3d(x_scale)))  # -> (b, 512, f, 7, 7)
            x_scale = self.spatial_pooling(x_scale)  # -> (b, 512, f, 1, 1)
            x_scale = x_scale.squeeze(3).squeeze(3)  # -> (b, 512, f)
            x_scale = x_scale.transpose(1, 2)  # -> (b, f, 512)

            # Similarity matrix
            x_sims = nn.functional.relu(self.sims(x_scale, x_scale, x_scale))  # -> (b, 4, f, f)
            multi_scales.append(x_sims)

        x = torch.cat(multi_scales, dim=1)  # -> (b, 4*scale_num, f, f)
        x_matrix = x  # similarity matrix to return

        x = nn.functional.relu(self.bn2(self.conv3x3(x)))  # -> (b, 32, f, f)
        x = self.dropout1(x)
        x = x.permute(0, 2, 3, 1)  # -> (b, f, f, 32)
        x = x.flatten(start_dim=2)  # -> (b, f, 32*f)
        x = nn.functional.relu(self.input_projection(x))  # -> (b, f, 512)
        x = self.ln1(x)

        x = x.transpose(0, 1)  # -> (f, b, 512)
        x = self.trans_encoder(x)
        x = x.transpose(0, 1)  # -> (b, f, 512)

        x = self.fc(x)  # -> (b, f, 1)
        x = x.squeeze(2)
        return x, x_matrix
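

# A minimal usage sketch, not part of the original module. It assumes the
# default Video Swin Transformer config from `video_swin_transformer.create_model`
# and a 4-frame clip, mirroring the docstring example above. Per the TransRAC
# paper, the first output is a per-frame density map, so the predicted
# repetition count can be read off as its sum over frames.
if __name__ == "__main__":
    from towhee.models import video_swin_transformer

    dummy_video = torch.rand(1, 3, 4, 200, 200)  # (b, c, t, h, w)
    backbone = video_swin_transformer.create_model()  # default configs
    model = TransRAC(backbone=backbone, num_frames=4)

    density, sim_matrix = model(dummy_video)  # (b, f), (b, 4*scale_num, f, f)
    count = density.sum(dim=1)  # predicted repetition count per clip
    print(density.shape, sim_matrix.shape, count)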