# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Inspired by https://github.com/PeizeSun/SparseR-CNN/blob/dff4c43a9526a6d0d2480abc833e78a7c29ddb1a/detectron2/config/defaults.py
#
# Original code from https://github.com/Atze00/MoViNet-pytorch
#
# Modified by Zilliz.
from typing import Any, Callable, Optional, Tuple, Union
from einops import rearrange
import torch
from torch.nn.modules.utils import _triple
import torch.nn.functional as F
from torch import nn, Tensor
from towhee.models.utils.general_utils import make_divisible
from towhee.models.utils.causal_module import CausalModule
from towhee.models.layers.activations import HardSigmoid
from towhee.models.layers.conv_bn_activation import Conv2dBNActivation, Conv3DBNActivation
from towhee.models.layers.padding_functions import same_padding
from towhee.models.layers.tf_avgpool3d import TfAvgPool3D
from towhee.models.layers.temporal_cg_avgpool3d import TemporalCGAvgPool3D
class ConvBlock3D(CausalModule):
    """
    Causal-capable 3D convolution block.

    Depending on ``conv_type`` this is either a single 3D convolution
    ("3d") or a (2+1)D factorization ("2plus1d"): a 2D spatial convolution
    run frame-by-frame, followed — when the temporal kernel is > 1 — by a
    second 2D convolution along the temporal axis over the flattened
    spatial dimensions. With ``causal=True`` temporal padding is applied
    only on the past side through a cached stream buffer, enabling
    frame-streaming inference. With ``tf_like=True`` the spatial dims use
    TensorFlow-style "same" padding (requires an odd temporal kernel and
    temporal stride 1).

    Args:
        in_planes (`int`): number of input channels.
        out_planes (`int`): number of output channels.
        kernel_size (`Union[int, Tuple[int, int, int]]`): (t, h, w) kernel.
        tf_like (`bool`): emulate TensorFlow "same" padding.
        causal (`bool`): pad only on the past side of the temporal axis.
        conv_type (`str`): "3d" or "2plus1d".
        padding (`Union[int, Tuple[int, int, int]]`): explicit padding.
        stride (`Union[int, Tuple[int, int, int]]`): (t, h, w) stride.
        norm_layer (`Optional[Callable[..., nn.Module]]`): norm factory.
        activation_layer (`Optional[Callable[..., nn.Module]]`): activation factory.
        bias (`bool`): whether the convolutions use a bias term.

    Raises:
        ValueError: on an even temporal kernel or temporal stride != 1 in
            tf_like mode, spatial stride larger than the kernel in tf_like
            mode, or an unknown ``conv_type``.
    """
    def __init__(
            self,
            in_planes: int,
            out_planes: int,
            *,
            kernel_size: Union[int, Tuple[int, int, int]],
            tf_like: bool,
            causal: bool,
            conv_type: str,
            padding: Union[int, Tuple[int, int, int]] = 0,
            stride: Union[int, Tuple[int, int, int]] = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = None,
            bias: bool = False,
            **kwargs: Any,
            ) -> None:
        super().__init__()
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        self.conv_2 = None
        if tf_like:
            # We need an odd kernel to have even padding
            # and stride == 1 to precompute padding.
            if kernel_size[0] % 2 == 0:
                raise ValueError("tf_like supports only odd"
                                 + " kernels for temporal dimension")
            padding = ((kernel_size[0] - 1) // 2, 0, 0)
            if stride[0] != 1:
                raise ValueError("illegal stride value, tf like supports"
                                 + " only stride == 1 for temporal dimension")
            if stride[1] > kernel_size[1] or stride[2] > kernel_size[2]:
                # these values are not tested so should be avoided
                raise ValueError("tf_like supports only"
                                 + " stride <= of the kernel size")
        if causal is True:
            # causal mode pads the temporal axis manually via the stream
            # buffer, so drop the symmetric temporal padding here
            padding = (0, padding[1], padding[2])
        if "2plus1d" not in conv_type and "3d" not in conv_type:
            # message fixed: the accepted value is "2plus1d", not "2plus2d"
            raise ValueError("only 2plus1d or 3d are "
                             + "allowed as 3d convolutions")
        if conv_type == "2plus1d":
            # spatial 2D convolution, applied independently per frame
            self.conv_1 = Conv2dBNActivation(in_planes,
                                             out_planes,
                                             kernel_size=(kernel_size[1],
                                                          kernel_size[2]),
                                             padding=(padding[1],
                                                      padding[2]),
                                             stride=(stride[1], stride[2]),
                                             activation_layer=activation_layer,
                                             norm_layer=norm_layer,
                                             bias=bias,
                                             **kwargs)
            if kernel_size[0] > 1:
                # temporal convolution over the flattened spatial axis.
                # NOTE(review): conv_2 consumes conv_1's output, which has
                # out_planes channels; `in_planes` here is only correct when
                # in_planes == out_planes — confirm before reusing this block
                # with differing channel counts and a temporal kernel > 1.
                self.conv_2 = Conv2dBNActivation(in_planes,
                                                 out_planes,
                                                 kernel_size=(kernel_size[0],
                                                              1),
                                                 padding=(padding[0], 0),
                                                 stride=(stride[0], 1),
                                                 activation_layer=activation_layer,
                                                 norm_layer=norm_layer,
                                                 bias=bias,
                                                 **kwargs)
        elif conv_type == "3d":
            self.conv_1 = Conv3DBNActivation(in_planes,
                                             out_planes,
                                             kernel_size=kernel_size,
                                             padding=padding,
                                             activation_layer=activation_layer,
                                             norm_layer=norm_layer,
                                             stride=stride,
                                             bias=bias,
                                             **kwargs)
        self.padding = padding
        self.kernel_size = kernel_size
        # number of past frames that must be buffered for causal padding
        self.dim_pad = self.kernel_size[0] - 1
        self.stride = stride
        self.causal = causal
        self.conv_type = conv_type
        self.tf_like = tf_like

    def _forward(self, x: Tensor) -> Tensor:
        """Run the convolution(s), handling causal buffering and the (2+1)D
        reshapes. ``x`` is laid out as (b, c, t, h, w)."""
        device = x.device
        if self.dim_pad > 0 and self.conv_2 is None and self.causal is True:
            # single-conv causal case: prepend buffered past frames
            x = self._cat_stream_buffer(x, device)
        shape_with_buffer = x.shape
        if self.conv_type == "2plus1d":
            # fold time into the batch so the 2D conv sees every frame
            x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.conv_1(x)
        if self.conv_type == "2plus1d":
            x = rearrange(x,
                          "(b t) c h w -> b c t h w",
                          t=shape_with_buffer[2])
            if self.conv_2 is not None:
                if self.dim_pad > 0 and self.causal is True:
                    x = self._cat_stream_buffer(x, device)
                w = x.shape[-1]
                # flatten h/w so the temporal conv runs as a plain 2D conv
                x = rearrange(x, "b c t h w -> b c t (h w)")
                x = self.conv_2(x)
                x = rearrange(x, "b c t (h w) -> b c t h w", w=w)
        return x

    def forward(self, x: Tensor) -> Tensor:
        if self.tf_like:
            # apply TensorFlow-style "same" padding on the spatial dims
            x = same_padding(x, x.shape[-2], x.shape[-1],
                             self.stride[-2], self.stride[-1],
                             self.kernel_size[-2], self.kernel_size[-1])
        x = self._forward(x)
        return x

    def _cat_stream_buffer(self, x: Tensor, device: torch.device) -> Tensor:
        """Prepend cached past frames along the temporal axis and refresh
        the cache with the newest frames."""
        if self.activation is None:
            # first streaming call: start from a zero-filled buffer
            self._setup_activation(x.shape)
        x = torch.cat((self.activation.to(device), x), 2)
        self._save_in_activation(x)
        return x

    def _save_in_activation(self, x: Tensor) -> None:
        """Cache the last ``dim_pad`` frames for the next streaming call."""
        assert self.dim_pad > 0
        self.activation = x[:, :, -self.dim_pad:, ...].clone().detach()

    def _setup_activation(self, input_shape: Tuple[float, ...]) -> None:
        """Allocate a zero buffer of ``dim_pad`` frames matching the input."""
        assert self.dim_pad > 0
        self.activation = torch.zeros(*input_shape[:2],  # type: ignore
                                      self.dim_pad,
                                      *input_shape[3:])
class SqueezeExcitation(nn.Module):
    """
    Squeeze-and-Excitation attention block for MoViNet.

    In non-causal mode the "squeeze" is a global (t, h, w) average pool.
    In causal (streaming) mode the squeeze uses a cumulative temporal
    average of the per-frame spatial means, concatenated with those spatial
    means, so no future frames leak into the scale — hence the doubled
    channel count fed to ``fc1``.

    Args:
        input_channels (`int`): number of channels of the input tensor.
        activation_2 (`nn.Module`): factory for the final gating activation
            (e.g. Sigmoid / HardSigmoid).
        activation_1 (`nn.Module`): factory for the mid activation.
        conv_type (`str`): "3d" or "2plus1d", forwarded to ConvBlock3D.
        causal (`bool`): use the streaming-safe cumulative squeeze.
        squeeze_factor (`int`): channel-reduction factor for the bottleneck.
        bias (`bool`): whether the 1x1x1 convolutions use a bias term.
    """
    def __init__(self, input_channels: int,
                 activation_2: nn.Module,
                 activation_1: nn.Module,
                 conv_type: str,
                 causal: bool,
                 squeeze_factor: int = 4,
                 bias: bool = True) -> None:
        super().__init__()
        self.causal = causal
        # causal mode concatenates cumulative + spatial averages -> 2x channels
        se_multiplier = 2 if causal else 1
        squeeze_channels = make_divisible(input_channels
                                          // squeeze_factor
                                          * se_multiplier, 8)
        # NOTE(review): attribute name (incl. the "cumualtive" misspelling)
        # is kept as-is so existing state dicts keep loading.
        self.temporal_cumualtive_gavg3d = TemporalCGAvgPool3D()
        self.fc1 = ConvBlock3D(input_channels*se_multiplier,
                               squeeze_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)
        self.activation_1 = activation_1()
        self.activation_2 = activation_2()
        self.fc2 = ConvBlock3D(squeeze_channels,
                               input_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)

    def _scale(self, input_tensor: Tensor) -> Tensor:
        """Compute the per-channel gating scale in [0, 1]."""
        if self.causal:
            # spatial mean per frame, then causal cumulative temporal mean
            x_space = torch.mean(input_tensor, dim=[3, 4], keepdim=True)
            scale = self.temporal_cumualtive_gavg3d(x_space)
            scale = torch.cat((scale, x_space), dim=1)
        else:
            scale = F.adaptive_avg_pool3d(input_tensor, 1)
        scale = self.fc1(scale)
        scale = self.activation_1(scale)
        scale = self.fc2(scale)
        return self.activation_2(scale)

    def forward(self, input_tensor: Tensor) -> Tensor:
        scale = self._scale(input_tensor)
        return scale * input_tensor
class BasicBneck(nn.Module):
    """
    MoViNet inverted-bottleneck block:
    (optional) expand -> depthwise conv -> squeeze-excite -> project,
    combined with an (optionally downsampled) residual branch scaled by a
    learnable ReZero parameter ``alpha`` (initialized to 0).

    Args:
        cfg ("CfgNode"): block configuration; must expose input_channels,
            expanded_channels, out_channels, kernel_size, padding, stride
            (a tuple), and padding_avg.
        causal (`bool`): build causal (streaming) convolutions.
        tf_like (`bool`): use TensorFlow-style padding/pooling.
        conv_type (`str`): "3d" or "2plus1d".
        norm_layer (`Optional[Callable[..., nn.Module]]`): norm factory.
        activation_layer (`Optional[Callable[..., nn.Module]]`): activation factory.

    Raises:
        ValueError: when cfg.stride is not (1, {1,2}, {1,2}).
    """
    def __init__(self,
                 cfg: "CfgNode",
                 causal: bool,
                 tf_like: bool,
                 conv_type: str,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 ) -> None:
        super().__init__()
        assert isinstance(cfg.stride, tuple) is True
        if not cfg.stride[0] == 1 or not 1 <= cfg.stride[1] <= 2 or not 1 <= cfg.stride[2] <= 2:
            raise ValueError("illegal stride value")
        self.res = None
        # Bug fix: initialize to None so forward()'s
        # `self.expand is not None` check cannot raise AttributeError when
        # the expansion stage is skipped (expanded_channels == out_channels).
        self.expand = None
        layers = []
        if cfg.expanded_channels != cfg.out_channels:
            # expand: 1x1x1 pointwise conv widening the channels
            self.expand = ConvBlock3D(
                in_planes=cfg.input_channels,
                out_planes=cfg.expanded_channels,
                kernel_size=(1, 1, 1),
                padding=(0, 0, 0),
                causal=causal,
                conv_type=conv_type,
                tf_like=tf_like,
                norm_layer=norm_layer,
                activation_layer=activation_layer
            )
        # depthwise conv (groups == channels)
        self.deep = ConvBlock3D(
            in_planes=cfg.expanded_channels,
            out_planes=cfg.expanded_channels,
            kernel_size=cfg.kernel_size,
            padding=cfg.padding,
            stride=cfg.stride,
            groups=cfg.expanded_channels,
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=activation_layer
        )
        # squeeze-and-excitation on the expanded features
        self.se = SqueezeExcitation(cfg.expanded_channels,
                                    causal=causal,
                                    activation_1=activation_layer,
                                    activation_2=(nn.Sigmoid
                                                  if conv_type == "3d"
                                                  else HardSigmoid),
                                    conv_type=conv_type
                                    )
        # project: 1x1x1 pointwise conv back to out_channels, no activation
        self.project = ConvBlock3D(
            cfg.expanded_channels,
            cfg.out_channels,
            kernel_size=(1, 1, 1),
            padding=(0, 0, 0),
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=nn.Identity
        )
        # residual branch needs pooling/projection when shape changes
        if not (cfg.stride == (1, 1, 1)
                and cfg.input_channels == cfg.out_channels):
            if cfg.stride != (1, 1, 1):
                if tf_like:
                    layers.append(TfAvgPool3D())
                else:
                    layers.append(nn.AvgPool3d((1, 3, 3),
                                  stride=cfg.stride,
                                  padding=cfg.padding_avg))
            layers.append(ConvBlock3D(
                in_planes=cfg.input_channels,
                out_planes=cfg.out_channels,
                kernel_size=(1, 1, 1),
                padding=(0, 0, 0),
                norm_layer=norm_layer,
                activation_layer=nn.Identity,
                causal=causal,
                conv_type=conv_type,
                tf_like=tf_like
            ))
            self.res = nn.Sequential(*layers)
        # ReZero: residual + alpha * block_output, alpha starts at 0
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, input_tensor: Tensor) -> Tensor:
        if self.res is not None:
            residual = self.res(input_tensor)
        else:
            residual = input_tensor
        if self.expand is not None:
            x = self.expand(input_tensor)
        else:
            x = input_tensor
        x = self.deep(x)
        x = self.se(x)
        x = self.project(x)
        result = residual + self.alpha * x
        return result