# Copyright 2022 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
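# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.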
# See the License for the specific language governing permissions and
# limitations under the License.
# Inspired by
# Original code from
# Modified by Zilliz.

from typing import Any, Callable, Optional, Tuple, Union
from einops import rearrange

import torch
from torch.nn.modules.utils import _triple
import torch.nn.functional as F
from torch import nn, Tensor

from towhee.models.utils.general_utils import make_divisible
from towhee.models.utils.causal_module import CausalModule

from towhee.models.layers.activations import HardSigmoid
from towhee.models.layers.conv_bn_activation import Conv2dBNActivation, Conv3DBNActivation
from towhee.models.layers.padding_functions import same_padding
from towhee.models.layers.tf_avgpool3d import TfAvgPool3D
from towhee.models.layers.temporal_cg_avgpool3d import TemporalCGAvgPool3D

class ConvBlock3D(CausalModule):
    """
    ConvBlock3D

    Video convolution block, built either as a full 3D convolution
    (``conv_type="3d"``) or as a (2+1)D factorisation (``conv_type="2plus1d"``):
    a per-frame spatial convolution followed, when ``kernel_size[0] > 1``, by a
    temporal convolution over ``(t, h*w)``.

    Args:
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        kernel_size (int or tuple): ``(t, h, w)`` kernel size; an int is
            expanded to a triple.
        tf_like (bool): emulate TF "same" padding. The last two dimensions are
            padded at runtime with ``same_padding``, and the temporal
            dimension is padded by ``(kernel_size[0] - 1) // 2``.
        causal (bool): drop temporal padding. Instead, prepend the last
            ``kernel_size[0] - 1`` temporal steps cached from the previous
            call (a stream buffer stored in ``self.activation``).
        conv_type (str): ``"2plus1d"`` or ``"3d"``.
        padding (int or tuple): ``(t, h, w)`` padding (overridden by
            ``tf_like`` / ``causal`` as described above).
        stride (int or tuple): ``(t, h, w)`` stride.
        norm_layer, activation_layer: forwarded to the conv-bn-activation
            layers.
        bias (bool): forwarded to the conv-bn-activation layers.
        **kwargs: forwarded to the conv-bn-activation layers (e.g. ``groups``).

    Raises:
        ValueError: if ``tf_like`` is used with an even temporal kernel, a
            temporal stride != 1, or a spatial stride larger than the kernel;
            or if ``conv_type`` is neither ``"2plus1d"`` nor ``"3d"``.
    """

    def __init__(
            self,
            in_planes: int,
            out_planes: int,
            *,
            kernel_size: Union[int, Tuple[int, int, int]],
            tf_like: bool,
            causal: bool,
            conv_type: str,
            padding: Union[int, Tuple[int, int, int]] = 0,
            stride: Union[int, Tuple[int, int, int]] = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = None,
            bias: bool = False,
            **kwargs: Any,
    ) -> None:
        super().__init__()
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        self.conv_2 = None
        if tf_like:
            # We need an odd temporal kernel to get symmetric padding,
            # and stride == 1 to be able to precompute that padding.
            if kernel_size[0] % 2 == 0:
                raise ValueError("tf_like supports only odd"
                                 + " kernels for temporal dimension")
            padding = ((kernel_size[0] - 1) // 2, 0, 0)
            if stride[0] != 1:
                raise ValueError("illegal stride value, tf like supports"
                                 + " only stride == 1 for temporal dimension")
            if stride[1] > kernel_size[1] or stride[2] > kernel_size[2]:
                # these values are not tested so should be avoided
                raise ValueError("tf_like supports only"
                                 + " stride <= of the kernel size")

        if causal is True:
            # Temporal context is supplied by the stream buffer instead.
            padding = (0, padding[1], padding[2])
        # Exact match: a substring test would accept e.g. "x3d" and then
        # leave self.conv_1 undefined.
        if conv_type not in ("2plus1d", "3d"):
            raise ValueError("only 2plus1d or 3d are "
                             + "allowed as 3d convolutions")

        if conv_type == "2plus1d":
            # Spatial (h, w) convolution, applied to every frame independently.
            self.conv_1 = Conv2dBNActivation(
                in_planes,
                out_planes,
                kernel_size=(kernel_size[1], kernel_size[2]),
                padding=(padding[1], padding[2]),
                stride=(stride[1], stride[2]),
                activation_layer=activation_layer,
                norm_layer=norm_layer,
                bias=bias,
                **kwargs)
            if kernel_size[0] > 1:
                # Temporal convolution over (t, h*w). It consumes the output
                # of conv_1, so its input width is out_planes.
                self.conv_2 = Conv2dBNActivation(
                    out_planes,
                    out_planes,
                    kernel_size=(kernel_size[0], 1),
                    padding=(padding[0], 0),
                    stride=(stride[0], 1),
                    activation_layer=activation_layer,
                    norm_layer=norm_layer,
                    bias=bias,
                    **kwargs)
        else:  # conv_type == "3d"
            self.conv_1 = Conv3DBNActivation(
                in_planes,
                out_planes,
                kernel_size=kernel_size,
                padding=padding,
                activation_layer=activation_layer,
                norm_layer=norm_layer,
                stride=stride,
                bias=bias,
                **kwargs)
        self.padding = padding
        self.kernel_size = kernel_size
        # Number of past temporal steps the causal stream buffer keeps.
        self.dim_pad = self.kernel_size[0] - 1
        self.stride = stride
        self.causal = causal
        self.conv_type = conv_type
        self.tf_like = tf_like

    def _forward(self, x: Tensor) -> Tensor:
        """Apply the convolution(s) to ``x`` laid out as (b, c, t, h, w)."""
        device = x.device
        # Single-conv path: the buffer is prepended before conv_1.
        if self.dim_pad > 0 and self.conv_2 is None and self.causal is True:
            x = self._cat_stream_buffer(x, device)
        shape_with_buffer = x.shape
        if self.conv_type == "2plus1d":
            # Fold time into the batch so conv_1 works frame by frame.
            x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.conv_1(x)
        if self.conv_type == "2plus1d":
            x = rearrange(x,
                          "(b t) c h w -> b c t h w",
                          t=shape_with_buffer[2])
            if self.conv_2 is not None:
                # Factorised path: the buffer feeds the temporal conv.
                if self.dim_pad > 0 and self.causal is True:
                    x = self._cat_stream_buffer(x, device)
                w = x.shape[-1]
                x = rearrange(x, "b c t h w -> b c t (h w)")
                x = self.conv_2(x)
                x = rearrange(x, "b c t (h w) -> b c t h w", w=w)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Optionally apply TF-style padding on the last two dims, then convolve."""
        if self.tf_like:
            x = same_padding(x, x.shape[-2], x.shape[-1],
                             self.stride[-2], self.stride[-1],
                             self.kernel_size[-2], self.kernel_size[-1])
        x = self._forward(x)
        return x

    def _cat_stream_buffer(self, x: Tensor, device: torch.device) -> Tensor:
        """Prepend the cached temporal context to ``x`` along dim 2, then cache
        the last ``dim_pad`` steps of the result for the next call."""
        if self.activation is None:
            # No cached context yet: start from zeros.
            self._setup_activation(x.shape)
        # Match x's device and dtype so the concatenation is well-defined.
        buffer = self.activation.to(device=device, dtype=x.dtype)
        x = torch.cat((buffer, x), 2)
        self._save_in_activation(x)
        return x

    def _save_in_activation(self, x: Tensor) -> None:
        """Cache (detached) the last ``dim_pad`` steps of ``x`` along dim 2."""
        assert self.dim_pad > 0
        self.activation = x[:, :, -self.dim_pad:, ...].clone().detach()

    def _setup_activation(self, input_shape: Tuple[int, ...]) -> None:
        """Initialise the buffer to zeros shaped like ``input_shape``, but with
        ``dim_pad`` steps along dim 2."""
        assert self.dim_pad > 0
        self.activation = torch.zeros(*input_shape[:2],  # type: ignore
                                      self.dim_pad,
                                      *input_shape[3:])
class SqueezeExcitation(nn.Module):
    """
    SqueezeExcitation

    Channel gate for (b, c, t, h, w) features: ``input * gate(input)``.

    Non-causal mode squeezes with a global average over all of (t, h, w).
    Causal mode averages over (h, w) only. It passes that through
    ``TemporalCGAvgPool3D`` (presumably a cumulative average over time —
    confirm in that module), then concatenates both results on the channel
    axis. This is why ``fc1``'s input width is doubled when ``causal``.

    Args:
        input_channels (int): channels of the gated tensor.
        activation_2 (class): instantiated with no arguments; applied last to
            produce the gate.
        activation_1 (class): instantiated with no arguments; applied between
            ``fc1`` and ``fc2``.
        conv_type (str): forwarded to the 1x1x1 ``ConvBlock3D`` layers.
        causal (bool): select the causal squeeze described above.
        squeeze_factor (int): channel reduction factor of the bottleneck.
        bias (bool): forwarded to the 1x1x1 ``ConvBlock3D`` layers.
    """

    def __init__(self,
                 input_channels: int,
                 activation_2: nn.Module,
                 activation_1: nn.Module,
                 conv_type: str,
                 causal: bool,
                 squeeze_factor: int = 4,
                 bias: bool = True) -> None:
        super().__init__()
        self.causal = causal
        # Causal mode concatenates two pooled tensors, doubling the width.
        se_multiplier = 2 if causal else 1
        squeeze_channels = make_divisible(
            input_channels // squeeze_factor * se_multiplier, 8)
        # NOTE: attribute name is misspelled but kept for state/API compat.
        self.temporal_cumualtive_gavg3d = TemporalCGAvgPool3D()
        self.fc1 = ConvBlock3D(input_channels * se_multiplier,
                               squeeze_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)
        self.activation_1 = activation_1()
        self.activation_2 = activation_2()
        self.fc2 = ConvBlock3D(squeeze_channels,
                               input_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)

    def _scale(self, input_tensor: Tensor) -> Tensor:
        """Compute the per-channel gate for ``input_tensor``."""
        if self.causal:
            x_space = torch.mean(input_tensor, dim=[3, 4], keepdim=True)
            scale = self.temporal_cumualtive_gavg3d(x_space)
            # Channel-wise concat: fc1 expects 2 * input_channels here.
            scale = torch.cat((scale, x_space), dim=1)
        else:
            scale = F.adaptive_avg_pool3d(input_tensor, 1)
        scale = self.fc1(scale)
        scale = self.activation_1(scale)
        scale = self.fc2(scale)
        return self.activation_2(scale)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Return ``input_tensor`` multiplied by its broadcast channel gate."""
        scale = self._scale(input_tensor)
        return scale * input_tensor
class BasicBneck(nn.Module):
    """
    BasicBneck

    Inverted-bottleneck block. The main branch is:

    1. An optional 1x1x1 expansion.
    2. A depthwise ``cfg.kernel_size`` convolution (``groups`` equal to the
       channel count).
    3. Squeeze-excitation.
    4. A 1x1x1 projection with no activation.

    The main branch is added to a residual scaled by a learnable ReZero
    ``alpha`` initialised to 0. When the stride is not (1, 1, 1) or the
    channel count changes, the residual is average-pooled (if strided) and
    projected by a 1x1x1 conv.

    Args:
        cfg: block config; reads ``stride`` (tuple), ``input_channels``,
            ``expanded_channels``, ``out_channels``, ``kernel_size``,
            ``padding`` and ``padding_avg``.
        causal (bool): forwarded to every ``ConvBlock3D`` and the SE layer.
        tf_like (bool): forwarded to the convs; also selects ``TfAvgPool3D``
            for the strided residual.
        conv_type (str): ``"2plus1d"`` or ``"3d"``; ``"3d"`` uses
            ``nn.Sigmoid`` as the SE gate, otherwise ``HardSigmoid``.
        norm_layer, activation_layer: forwarded to the convs.

    Raises:
        ValueError: if the temporal stride is not 1 or a spatial stride is
            outside [1, 2].
    """

    def __init__(self,
                 cfg: "CfgNode",
                 causal: bool,
                 tf_like: bool,
                 conv_type: str,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 ) -> None:
        super().__init__()
        assert isinstance(cfg.stride, tuple)
        if (not cfg.stride[0] == 1
                or not 1 <= cfg.stride[1] <= 2
                or not 1 <= cfg.stride[2] <= 2):
            raise ValueError("illegal stride value")
        self.res = None
        # Always define `expand`; forward() checks it against None.
        self.expand = None

        layers = []
        # NOTE(review): expansion is keyed on out_channels, while `deep`
        # expects expanded_channels inputs. A config with
        # expanded == out != input would mismatch; confirm against configs.
        if cfg.expanded_channels != cfg.out_channels:
            # expand
            self.expand = ConvBlock3D(
                in_planes=cfg.input_channels,
                out_planes=cfg.expanded_channels,
                kernel_size=(1, 1, 1),
                padding=(0, 0, 0),
                causal=causal,
                conv_type=conv_type,
                tf_like=tf_like,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        # depthwise
        self.deep = ConvBlock3D(
            in_planes=cfg.expanded_channels,
            out_planes=cfg.expanded_channels,
            kernel_size=cfg.kernel_size,
            padding=cfg.padding,
            stride=cfg.stride,
            groups=cfg.expanded_channels,
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
        )
        # squeeze-excitation
        self.se = SqueezeExcitation(
            cfg.expanded_channels,
            causal=causal,
            activation_1=activation_layer,
            activation_2=(nn.Sigmoid if conv_type == "3d" else HardSigmoid),
            conv_type=conv_type,
        )
        # project (linear: no activation)
        self.project = ConvBlock3D(
            cfg.expanded_channels,
            cfg.out_channels,
            kernel_size=(1, 1, 1),
            padding=(0, 0, 0),
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=nn.Identity,
        )

        # The residual needs reshaping when the spatial size or width changes.
        if not (cfg.stride == (1, 1, 1)
                and cfg.input_channels == cfg.out_channels):
            if cfg.stride != (1, 1, 1):
                if tf_like:
                    layers.append(TfAvgPool3D())
                else:
                    layers.append(nn.AvgPool3d((1, 3, 3),
                                               stride=cfg.stride,
                                               padding=cfg.padding_avg))
            layers.append(ConvBlock3D(
                in_planes=cfg.input_channels,
                out_planes=cfg.out_channels,
                kernel_size=(1, 1, 1),
                padding=(0, 0, 0),
                norm_layer=norm_layer,
                activation_layer=nn.Identity,
                causal=causal,
                conv_type=conv_type,
                tf_like=tf_like,
            ))
            self.res = nn.Sequential(*layers)
        # ReZero: the main branch starts as a no-op and is learned in.
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Return ``residual + alpha * main_branch(input_tensor)``."""
        if self.res is not None:
            residual = self.res(input_tensor)
        else:
            residual = input_tensor
        if self.expand is not None:
            x = self.expand(input_tensor)
        else:
            x = input_tensor
        x = self.deep(x)
        x = self.se(x)
        x = self.project(x)
        result = residual + self.alpha * x
        return result