Source code for towhee.models.movinet.movinet

# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Code inspired by https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv2.html
# https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html
#
# Original code from https://github.com/Atze00/MoViNet-pytorch
#
# Modified by Zilliz.

from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from towhee.models.movinet.movinet_block import ConvBlock3D, BasicBneck
from towhee.models.utils.causal_module import CausalModule
from towhee.models.layers.activations import Swish
from towhee.models.layers.temporal_cg_avgpool3d import TemporalCGAvgPool3D
from towhee.models.movinet.config import _C

[docs]class MoViNet(nn.Module): """ Args: causal(`string`): Causal mode. pretrained(`bool`): Pretrained models. If pretrained is True, num_classes is set to 600, conv_type is set to "3d" if causal is False, "2plus1d" if causal is True tf_like is set to True. num_classes(`int`): Number of classes for classifcation. conv_type(`string`): Type of convolution either 3d or 2plus1d tf_like(`bool`): Tf_like behaviour, basically same padding for convolutions. """
[docs] def __init__(self, cfg: "CfgNode", causal: bool = True, pretrained: bool = False, num_classes: int = 600, conv_type: str = "3d", tf_like: bool = False ) -> None: super().__init__() if pretrained: tf_like = True num_classes = 600 conv_type = "2plus1d" if causal else "3d" blocks_dic = OrderedDict() norm_layer = nn.BatchNorm3d if conv_type == "3d" else nn.BatchNorm2d activation_layer = Swish if conv_type == "3d" else nn.Hardswish # conv1 self.conv1 = ConvBlock3D( in_planes=cfg.conv1.input_channels, out_planes=cfg.conv1.out_channels, kernel_size=cfg.conv1.kernel_size, stride=cfg.conv1.stride, padding=cfg.conv1.padding, causal=causal, conv_type=conv_type, tf_like=tf_like, norm_layer=norm_layer, activation_layer=activation_layer ) # blocks for i, block in enumerate(cfg.blocks): for j, basicblock in enumerate(block): blocks_dic[f"b{i}_l{j}"] = BasicBneck(basicblock, causal=causal, conv_type=conv_type, tf_like=tf_like, norm_layer=norm_layer, activation_layer=activation_layer ) self.blocks = nn.Sequential(blocks_dic) # conv7 self.conv7 = ConvBlock3D( in_planes=cfg.conv7.input_channels, out_planes=cfg.conv7.out_channels, kernel_size=cfg.conv7.kernel_size, stride=cfg.conv7.stride, padding=cfg.conv7.padding, causal=causal, conv_type=conv_type, tf_like=tf_like, norm_layer=norm_layer, activation_layer=activation_layer ) # pool self.classifier = nn.Sequential( # dense9 ConvBlock3D(cfg.conv7.out_channels, cfg.dense9.hidden_dim, kernel_size=(1, 1, 1), tf_like=tf_like, causal=causal, conv_type=conv_type, bias=True), Swish(), nn.Dropout(p=0.2, inplace=True), # dense10d ConvBlock3D(cfg.dense9.hidden_dim, num_classes, kernel_size=(1, 1, 1), tf_like=tf_like, causal=causal, conv_type=conv_type, bias=True), ) if causal: self.cgap = TemporalCGAvgPool3D() if pretrained: if causal: if cfg.name not in ["A0", "A1", "A2"]: raise ValueError("Only A0,A1,A2 streaming" + "networks are available pretrained") state_dict = (torch.hub .load_state_dict_from_url(cfg.stream_weights)) else: state_dict = torch.hub.load_state_dict_from_url(cfg.weights) self.load_state_dict(state_dict) else: self.apply(self._weight_init) self.causal = causal
def avg(self, x: Tensor) -> Tensor: if self.causal: avg = F.adaptive_avg_pool3d(x, (x.shape[2], 1, 1)) avg = self.cgap(avg)[:, :, -1:] else: avg = F.adaptive_avg_pool3d(x, 1) return avg @staticmethod def _weight_init(m): if isinstance(m, nn.Conv3d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.GroupNorm)): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) def forward_features(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.blocks(x) x = self.conv7(x) x = self.avg(x) return x def head(self, x: Tensor) -> Tensor: x = self.classifier(x) x = x.flatten(1) return x def _forward_impl(self, x: Tensor) -> Tensor: x = self.forward_features(x) x = self.head(x) return x
[docs] def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x)
@staticmethod def _clean_activation_buffers(m): if issubclass(type(m), CausalModule): m.reset_activation() def clean_activation_buffers(self) -> None: self.apply(self._clean_activation_buffers)
[docs]def create_model( model_name: str = "movineta0", pretrained: bool = False, causal: bool = False, device: str = None, ): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if model_name == "movineta0": model_config = _C.MODEL.MoViNetA0 elif model_name == "movineta1": model_config = _C.MODEL.MoViNetA1 elif model_name == "movineta2": model_config = _C.MODEL.MoViNetA2 elif model_name == "movineta3": model_config = _C.MODEL.MoViNetA3 elif model_name == "movineta4": model_config = _C.MODEL.MoViNetA4 elif model_name == "movineta5": model_config = _C.MODEL.MoViNetA5 else: raise AttributeError(f"Invalid model_name {model_name}.") model = MoViNet( cfg = model_config, causal = causal, pretrained = pretrained, num_classes = 600, conv_type = "3d", tf_like = False ) model.to(device) return model