# Built on top of the original implementation at https://github.com/facebookresearch/SlowFast
#
# Modifications by Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
from typing import Tuple, List
import numpy
import torch
from torch import nn
from torch.nn.init import trunc_normal_
from einops import rearrange
from towhee.models.layers.layers_with_relprop import GELU, Linear, Dropout, Conv2d, Conv3d, Softmax, Sigmoid, LayerNorm, \
Einsum, Add, Clone, MaxPool3d, IndexSelect


def find_most_h_w(all_layer_matrices):
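    """Return the most common (mode) size of the last two dims across all layer attention maps."""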
shape_info = torch.vstack(
[torch.as_tensor([layer_matrix.shape[-2] for layer_matrix in all_layer_matrices]),
torch.as_tensor([layer_matrix.shape[-1] for layer_matrix in all_layer_matrices])])
return torch.mode(shape_info).values[0], torch.mode(shape_info).values[1]


def resize_last_dim_to_most(layer_matrix, last_dim_size):
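    """Resize the last (patch-token) dimension of an attention map to last_dim_size.

    The cls-token column is kept as-is; the remaining columns are resized with nearest-neighbour
    interpolation and scaled by the resize factor.
    """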
if layer_matrix.shape[-1] == last_dim_size:
return layer_matrix
else:
        # Keep the cls-token column (index 0) aside and interpolate only the patch-token columns.
        cls_token, attn = layer_matrix[..., 0], layer_matrix[..., 1:]
if attn.shape[-1] > last_dim_size - 1:
assert attn.shape[-1] % (last_dim_size - 1) == 0
else:
assert (last_dim_size - 1) % attn.shape[-1] == 0
factor = (last_dim_size - 1) / attn.shape[-1]
attn = torch.nn.functional.interpolate(attn, size=last_dim_size - 1, mode="nearest")
attn = attn * factor
return torch.cat([cls_token.unsqueeze(dim=-1), attn], dim=-1)


def align_scale(all_layer_matrices):
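    """Resize every layer's attention map to the most common (H, W) so the maps can be multiplied together."""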
most_attn_h, most_attn_w = find_most_h_w(all_layer_matrices)
aligned_layer_matrices = []
for layer_matrix in all_layer_matrices:
layer_matrix = resize_last_dim_to_most(layer_matrix, most_attn_w)
layer_matrix = layer_matrix.permute(0, 2, 1)
layer_matrix = resize_last_dim_to_most(layer_matrix, most_attn_h)
layer_matrix = layer_matrix.permute(0, 2, 1)
aligned_layer_matrices.append(layer_matrix)
return aligned_layer_matrices


def compute_rollout_attention(all_layer_matrices, start_layer=0):
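    """Attention rollout (Abnar & Zuidema, 2020): add the identity to each layer's attention map to
    account for the residual connection, then multiply the maps from start_layer onwards:

        rollout = A_L @ ... @ A_{start_layer},  with  A_i = attn_i + I
    """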
all_layer_matrices = align_scale(all_layer_matrices)
num_tokens = all_layer_matrices[0].shape[1]
batch_size = all_layer_matrices[0].shape[0]
eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
joint_attention = all_layer_matrices[start_layer]
for i in range(start_layer + 1, len(all_layer_matrices)):
joint_attention = all_layer_matrices[i].bmm(joint_attention)
return joint_attention


class AttentionPool(nn.Module):
"""
    Pooling module used by multiscale attention: the token sequence is reshaped into a (T, H, W)
    volume, pooled, reshaped back into a sequence, and optionally normalized.
::
Input
↓
Reshape
↓
Pool
↓
Reshape
↓
norm
Args:
thw_shape(List):
the shape of the input tensor (before flattening).
pool(Callable):
Pool operation that is applied to the input tensor.
If pool is None, return the input tensor.
has_cls_embed(bool):
whether the input tensor contains cls token. Pool operation excludes cls token.
norm(Callable):
Optional normalization operation applied to tensor after pool.
Returns:
tensor(torch.Tensor):
Input tensor after pool.
thw_shape(List[int]):
Output tensor shape (before flattening).
"""
    def __init__(
self,
pool=None,
has_cls_embed=True,
norm=None
) -> None:
super().__init__()
self.pool = pool
self.has_cls_embed = has_cls_embed
self.norm = norm
self.tensor_dim = 4

    def forward(self, x: torch.Tensor, thw_shape) -> Tuple[torch.Tensor, List[int]]:
if self.pool is None:
return x, thw_shape
self.tensor_dim = x.ndim
if self.tensor_dim == 4:
pass
elif self.tensor_dim == 3:
x = x.unsqueeze(1)
else:
raise NotImplementedError(f"Unsupported input dimension {x.shape}")
if self.has_cls_embed:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
b, n, _, c = x.shape
t, h, w = thw_shape
x = rearrange(x, "b n (t h w) c -> (b n) c t h w", b=b, n=n, c=c, t=t, h=h, w=w)
x = self.pool(x)
out_thw_shape = [x.shape[2], x.shape[3], x.shape[4]]
x = rearrange(x, "(b n) c t h w -> b n (t h w) c", b=b, n=n, c=c)
if self.has_cls_embed:
x = torch.cat((cls_tok, x), dim=2)
if self.norm is not None:
x = self.norm(x)
if self.tensor_dim == 4:
pass
else:
x = x.squeeze(1)
return x, out_thw_shape
def relprop(self, cam, **kwargs):
thw = kwargs.pop("thw")
t, h, w = thw[0], thw[1], thw[2]
if self.pool is None:
return cam, thw
if self.tensor_dim == 4:
pass
else:
cam = cam.unsqueeze(1)
if self.norm is not None:
cam = self.norm.relprop(cam, **kwargs)
if self.has_cls_embed:
cls_tok, cam = cam[:, :, :1, :], cam[:, :, 1:, :]
b, n, _, c = cam.shape
cam = rearrange(cam, " b n (t h w) c -> (b n) c t h w", t=t, h=h, w=w)
cam = self.pool.relprop(cam, **kwargs)
input_thw = cam.shape[-3:]
cam = rearrange(cam, "(b n) c t h w -> b n (t h w) c", b=b, n=n, c=c)
if self.has_cls_embed:
cam = torch.cat((cls_tok, cam), dim=2)
if self.tensor_dim == 4:
pass
        elif self.tensor_dim == 3:  # tensor_dim == 3: the side-way (residual skip) input case
cam = cam.squeeze(1)
else:
raise NotImplementedError(f"Unsupported input dimension {cam.shape}")
return cam, input_thw
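

# A minimal usage sketch of AttentionPool (illustrative only; the shapes and pooling settings below
# are assumptions, not values taken from a specific MViT config). A (B, heads, 1 + T*H*W, C) token
# tensor with a cls token is pooled by a stride-2 spatial max pool, halving H and W:
#
#   pool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
#   atn_pool = AttentionPool(pool=pool, has_cls_embed=True)
#   x = torch.randn(2, 8, 1 + 4 * 14 * 14, 96)    # (B, heads, 1 + T*H*W, head_dim)
#   y, thw = atn_pool(x, [4, 14, 14])             # y: (2, 8, 1 + 4*7*7, 96), thw: [4, 7, 7]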


def round_width(width, multiplier, min_width=1, divisor=1):
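    """Scale width by multiplier and round the result to a multiple of divisor, keeping it at least
    min_width and no smaller than 90% of the scaled width.

    E.g. round_width(100, 1.0, divisor=8) == 104 and round_width(100, 0.5, divisor=8) == 48.
    """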
if not multiplier:
return width
width *= multiplier
min_width = min_width or divisor
width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
if width_out < 0.9 * width:
width_out += divisor
return int(width_out)


class Mlp(nn.Module):
"""
    Multi-layer perceptron (MLP) module with two linear layers and optional dropout.
"""
    def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=GELU,
drop_rate=0.0,
):
super().__init__()
self.drop_rate = drop_rate
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = Linear(in_features, hidden_features) # pylint: disable=too-many-function-args
self.act = act_layer()
self.fc2 = Linear(hidden_features, out_features) # pylint: disable=too-many-function-args
if self.drop_rate > 0.0:
self.drop = Dropout(drop_rate) # pylint: disable=too-many-function-args

    def forward(self, x):
x = self.fc1(x)
x = self.act(x)
if self.drop_rate > 0.0:
x = self.drop(x)
x = self.fc2(x)
if self.drop_rate > 0.0:
x = self.drop(x)
return x
def relprop(self, cam, **kwargs):
if self.drop_rate > 0.0:
cam = self.drop.relprop(cam, **kwargs)
cam = self.fc2.relprop(cam, **kwargs)
if self.drop_rate > 0.0:
cam = self.drop.relprop(cam, **kwargs)
cam = self.act.relprop(cam, **kwargs)
cam = self.fc1.relprop(cam, **kwargs)
return cam


class Permute(nn.Module):
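    """Permute the dimensions of the input tensor according to dims."""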
    def __init__(self, dims):
super().__init__()
self.dims = dims

    def forward(self, x):
return x.permute(*self.dims)


def drop_path_func(x, drop_prob: float = 0.0, training: bool = False):
"""
    Stochastic depth per sample: the whole path is dropped for a sample with probability drop_prob,
    and surviving samples are rescaled by 1 / keep_prob so the expectation is unchanged.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
mask.floor_() # binarize
output = x.div(keep_prob) * mask
return output


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

    def forward(self, x):
return drop_path_func(x, self.drop_prob, self.training)
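

# Note: DropPath is only active in training mode; in eval mode (or with drop_prob=0) it is an
# identity. For example, DropPath(0.1) zeroes the residual branch for roughly 10% of samples during
# training and rescales the remaining samples by 1 / 0.9.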


class PatchEmbed(nn.Module):
"""
    Patch embedding: project the input video (or image) into a flattened token sequence with a
    strided 3D (or 2D) convolution.
"""
    def __init__(
self,
dim_in=3,
dim_out=768,
kernel=(1, 16, 16),
stride=(1, 4, 4),
padding=(1, 7, 7),
conv_2d=False,
):
super().__init__()
self.conv_2d = conv_2d
if self.conv_2d:
conv = Conv2d
else:
conv = Conv3d
self.proj = conv( # pylint: disable=too-many-function-args,unexpected-keyword-arg
dim_in,
dim_out,
kernel_size=kernel,
stride=stride,
padding=padding,
)
self.output_t = 1
self.output_h = None
self.output_w = None

    def forward(self, x):
x = self.proj(x)
if self.conv_2d:
self.output_h, self.output_w = x.shape[-2:]
else:
self.output_t, self.output_h, self.output_w = x.shape[-3:]
# B C (T) H W -> B (T)HW C
return x.flatten(2).transpose(1, 2)
def relprop(self, cam, **kwargs):
cam = cam.transpose(1, 2)
# B C (T)HW
if self.conv_2d:
cam = cam.reshape(cam.shape[0], cam.shape[1], self.output_h, self.output_w)
else:
cam = cam.reshape(cam.shape[0], cam.shape[1], self.output_t, self.output_h, self.output_w)
# B C (T) H W
return self.proj.relprop(cam, **kwargs)
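

# A shape sketch for PatchEmbed (illustrative; the kernel/stride/padding below are the usual MViT-B
# video patchification settings, not values hard-coded in this file):
#
#   patch_embed = PatchEmbed(dim_in=3, dim_out=96, kernel=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3))
#   x = torch.randn(2, 3, 16, 224, 224)      # (B, C, T, H, W)
#   tokens = patch_embed(x)                  # (2, 8 * 56 * 56, 96), i.e. an 8 x 56 x 56 token grid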


class MultiScaleAttention(nn.Module):
"""
    Multiscale (pooling) attention. Query, key and value tokens are pooled (with conv, avg or max
    pooling) before attention: the query pooling stride sets the output resolution, while key/value
    pooling shrinks the attention matrix and its cost.
"""
    def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
drop_rate=0.0,
kernel_q=(1, 1, 1),
kernel_kv=(1, 1, 1),
stride_q=(1, 1, 1),
stride_kv=(1, 1, 1),
norm_layer=LayerNorm,
has_cls_embed=True,
# Options include `conv`, `avg`, and `max`.
mode="conv",
):
super().__init__()
self.drop_rate = drop_rate
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.has_cls_embed = has_cls_embed
padding_q = [int(q // 2) for q in kernel_q]
padding_kv = [int(kv // 2) for kv in kernel_kv]
self.qkv = Linear(dim, dim * 3, bias=qkv_bias) # pylint: disable=too-many-function-args,unexpected-keyword-arg
self.proj = Linear(dim, dim) # pylint: disable=too-many-function-args
# self.show_attn = nn.Softmax(dim=-1)
if drop_rate > 0.0:
self.proj_drop = nn.Dropout(drop_rate)
# A = Q*K^T
self.matmul1 = Einsum("bhid,bhjd->bhij")
# attn = A*V
self.matmul2 = Einsum("bhij,bhjd->bhid")
self.softmax = Softmax(dim=-1) # pylint: disable=unexpected-keyword-arg
# Skip pooling with kernel and stride size of (1, 1, 1).
if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
kernel_q = ()
if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
kernel_kv = ()
pool_q = pool_k = pool_v = norm_q = norm_k = norm_v = None
if mode == "avg":
pool_q = (
nn.AvgPool3d(kernel_q, stride_q, padding_q, ceil_mode=False)
if len(kernel_q) > 0
else None
)
pool_k = (
nn.AvgPool3d(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
pool_v = (
nn.AvgPool3d(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
elif mode == "max":
pool_q = (
nn.MaxPool3d(kernel_q, stride_q, padding_q, ceil_mode=False)
if len(kernel_q) > 0
else None
)
pool_k = (
nn.MaxPool3d(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
pool_v = (
nn.MaxPool3d(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
elif mode == "conv":
pool_q = (
Conv3d( # pylint: disable=too-many-function-args,unexpected-keyword-arg
head_dim,
head_dim,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=head_dim,
bias=False,
)
if len(kernel_q) > 0
else None
)
norm_q = norm_layer(head_dim) if len(kernel_q) > 0 else None
pool_k = (
Conv3d( # pylint: disable=too-many-function-args,disable=unexpected-keyword-arg
head_dim,
head_dim,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=head_dim,
bias=False,
)
if len(kernel_kv) > 0
else None
)
norm_k = norm_layer(head_dim) if len(kernel_kv) > 0 else None
pool_v = (
Conv3d( # pylint: disable=too-many-function-args,unexpected-keyword-arg
head_dim,
head_dim,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=head_dim,
bias=False,
)
if len(kernel_kv) > 0
else None
)
norm_v = norm_layer(head_dim) if len(kernel_kv) > 0 else None
else:
            raise NotImplementedError(f"Unsupported mode {mode}")
self.atn_pool_q = AttentionPool(
pool=pool_q,
# thw_shape=thw_shape,
has_cls_embed=self.has_cls_embed,
norm=norm_q,
)
self.atn_pool_k = AttentionPool(
pool=pool_k,
# thw_shape=thw_shape,
has_cls_embed=self.has_cls_embed,
norm=norm_k,
)
self.atn_pool_v = AttentionPool(
pool=pool_v,
# thw_shape=thw_shape,
has_cls_embed=self.has_cls_embed,
norm=norm_v,
)
self.attn_cam = None
self.attn = None
self.v = None
self.v_cam = None
self.attn_gradients = None
def get_attn(self):
return self.attn
def save_attn(self, attn):
self.attn = attn
def save_attn_cam(self, cam):
self.attn_cam = cam
def get_attn_cam(self):
return self.attn_cam
def get_v(self):
return self.v
def save_v(self, v):
self.v = v
def save_v_cam(self, cam):
self.v_cam = cam
def get_v_cam(self):
return self.v_cam
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients

    def forward(self, x, thw_shape):
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads)
self.save_v(v)
q, q_shape = self.atn_pool_q(q, thw_shape)
k, _ = self.atn_pool_k(k, thw_shape)
v, _ = self.atn_pool_v(v, thw_shape)
dots = self.matmul1([q, k]) * self.scale
attn = self.softmax(dots)
self.save_attn(attn)
attn.register_hook(self.save_attn_gradients)
x = self.matmul2([attn, v])
x = rearrange(x, "b h n d -> b n (h d)")
x = self.proj(x)
if self.drop_rate > 0.0:
x = self.proj_drop(x)
return x, q_shape
def relprop(self, cam, **kwargs):
thw = kwargs.pop("thw")
if self.drop_rate > 0.0:
cam = self.proj_drop.relprop(cam, **kwargs)
cam = self.proj.relprop(cam, **kwargs)
cam = rearrange(cam, "b n (h d) -> b h n d", h=self.num_heads)
# attn = A*V
(cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
cam1 /= 2
cam_v /= 2
self.save_v_cam(cam_v)
self.save_attn_cam(cam1)
cam1 = self.softmax.relprop(cam1, **kwargs)
# A = Q*K^T
(cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
cam_q /= 2
cam_k /= 2
t = thw[0]
        # Recover the pooled key/value spatial size (assumes a square spatial grid).
        kv_h = kv_w = int(math.sqrt((cam_k.shape[-2] - 1) // t))
cam_q, input_thw = self.atn_pool_q.relprop(cam_q, thw=thw, **kwargs)
cam_k, _ = self.atn_pool_k.relprop(cam_k, thw=[t, kv_h, kv_w], **kwargs)
cam_v, _ = self.atn_pool_v.relprop(cam_v, thw=[t, kv_h, kv_w], **kwargs)
cam_qkv = rearrange([cam_q, cam_k, cam_v], "qkv b h n d -> b n (qkv h d)", qkv=3, h=self.num_heads)
return self.qkv.relprop(cam_qkv, **kwargs), input_thw
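

# A minimal sketch of MultiScaleAttention with query and key/value pooling (illustrative; the
# dimensions, strides and head count are assumptions for the example only):
#
#   attn = MultiScaleAttention(dim=96, num_heads=1, kernel_q=(1, 3, 3), stride_q=(1, 2, 2),
#                              kernel_kv=(1, 3, 3), stride_kv=(1, 2, 2), mode="conv")
#   x = torch.randn(2, 1 + 4 * 14 * 14, 96)      # (B, 1 + T*H*W, C) with a cls token
#   y, q_shape = attn(x, [4, 14, 14])            # y: (2, 1 + 4*7*7, 96), q_shape: [4, 7, 7]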


class MultiScaleBlock(nn.Module):
"""
    Multiscale transformer block for MViT: pooled multi-head attention with a pooled residual (skip)
    path, followed by an MLP; the channel dimension may expand from dim to dim_out via a linear
    projection on the second residual branch.
"""
    def __init__(
self,
dim,
dim_out,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop_rate=0.0,
drop_path=0.0,
act_layer=GELU,
norm_layer=LayerNorm,
up_rate=None,
kernel_q=(1, 1, 1),
kernel_kv=(1, 1, 1),
kernel_skip=(1, 1, 1),
stride_q=(1, 1, 1),
stride_kv=(1, 1, 1),
stride_skip=(1, 1, 1),
mode="conv",
has_cls_embed=True,
):
super().__init__()
self.add1 = Add()
self.add2 = Add()
self.clone1 = Clone()
self.clone2 = Clone()
self.clone_norm = Clone()
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
padding_skip = [int(skip // 2) for skip in kernel_skip]
self.attn = MultiScaleAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
drop_rate=drop_rate,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=LayerNorm,
has_cls_embed=has_cls_embed,
mode=mode,
)
self.drop_path = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.has_cls_embed = has_cls_embed
if up_rate is not None and up_rate > 1:
mlp_dim_out = dim * up_rate
else:
mlp_dim_out = dim_out
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
out_features=mlp_dim_out,
act_layer=act_layer,
drop_rate=drop_rate,
)
if dim != dim_out:
self.proj = Linear(dim, dim_out) # pylint: disable=too-many-function-args
self.pool_skip = (
MaxPool3d( # pylint: disable=too-many-function-args,unexpected-keyword-arg
kernel_skip, stride_skip, padding_skip, ceil_mode=False
)
if len(kernel_skip) > 0
else None
)
self.atn_pool = AttentionPool(
pool=self.pool_skip,
# thw_shape=thw_shape,
has_cls_embed=self.has_cls_embed
)

    def forward(self, x, thw_shape):
x1, x2 = self.clone1(x, 2)
x2, thw_shape_new = self.attn(self.norm1(x2), thw_shape)
x1, _ = self.atn_pool(x1, thw_shape)
x = self.add1([x1, self.drop_path(x2)])
if self.dim != self.dim_out:
x_norm = self.norm2(x)
x_norm2, x_norm1 = self.clone_norm(x_norm, 2)
x2 = self.mlp(x_norm2)
x1 = self.proj(x_norm1)
x = self.add2([x1, self.drop_path(x2)])
else:
x1, x2 = self.clone2(x, 2)
x_norm = self.norm2(x2)
x2 = self.mlp(x_norm)
x = self.add2([x1, self.drop_path(x2)])
return x, thw_shape_new
def relprop(self, cam, **kwargs):
thw = kwargs.pop("thw")
if self.dim != self.dim_out:
(cam1, cam2) = self.add2.relprop(cam, **kwargs)
cam_norm1 = self.proj.relprop(cam1, **kwargs)
cam_norm2 = self.mlp.relprop(cam2, **kwargs)
cam_norm = self.clone_norm.relprop((cam_norm2, cam_norm1), **kwargs)
cam = self.norm2.relprop(cam_norm, **kwargs)
else:
(cam1, cam2) = self.add2.relprop(cam, **kwargs)
cam_norm = self.mlp.relprop(cam2, **kwargs)
cam2 = self.norm2.relprop(cam_norm, **kwargs)
cam = self.clone2.relprop((cam1, cam2), **kwargs)
(cam1, cam2) = self.add1.relprop(cam, **kwargs)
cam1, _ = self.atn_pool.relprop(cam1, thw=thw, **kwargs)
cam2, input_thw = self.attn.relprop(cam2, thw=thw, **kwargs)
cam2 = self.norm1.relprop(cam2, **kwargs)
cam = self.clone1.relprop((cam1, cam2), **kwargs)
return cam, input_thw


class MViT(nn.Module):
"""
Multiscale Vision Transformers
Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer
https://arxiv.org/abs/2104.11227
"""
    def __init__(self,
patch_2d: bool,
patch_stride: List,
embed_dim: int,
num_heads: int,
mlp_ratio: float,
qkv_bias: bool,
dropout_rate: float,
depth: int,
droppath_rate: float,
mode: str,
cls_embed_on: bool,
sep_pos_embed: bool,
norm: str,
patch_kernel: List,
patch_padding: List,
pool_q_kernel: List,
pool_kv_kernel: List,
pool_skip_kernel: List,
pool_q_stride: List,
pool_kv_stride: List,
pool_skip_stride: List,
dim_mul_arg: List,
head_mul_arg: List,
norm_stem: bool,
num_classes: int,
head_act: str,
train_crop_size: int,
test_crop_size: int,
num_frames: int,
input_channel_num: List,
):
super().__init__()
assert train_crop_size == test_crop_size
spatial_size = train_crop_size
temporal_size = num_frames
in_chans = input_channel_num[0]
use_2d_patch = patch_2d
self.patch_stride = patch_stride
if use_2d_patch:
self.patch_stride = [1] + self.patch_stride
self.drop_rate = dropout_rate
drop_path_rate = droppath_rate
self.cls_embed_on = cls_embed_on
self.sep_pos_embed = sep_pos_embed
if norm == "layernorm":
norm_layer = partial(LayerNorm, eps=1e-6)
else:
raise NotImplementedError("Only supports layernorm.")
self.num_classes = num_classes
self.patch_embed = PatchEmbed(
dim_in=in_chans,
dim_out=embed_dim,
kernel=patch_kernel,
stride=patch_stride,
padding=patch_padding,
conv_2d=use_2d_patch,
)
self.input_dims = [temporal_size, spatial_size, spatial_size]
assert self.input_dims[1] == self.input_dims[2]
self.patch_dims = [
self.input_dims[i] // self.patch_stride[i]
for i in range(len(self.input_dims))
]
num_patches = math.prod(self.patch_dims)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
if self.cls_embed_on:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed_dim = num_patches + 1
else:
pos_embed_dim = num_patches
if self.sep_pos_embed:
self.pos_embed_spatial = nn.Parameter(
torch.zeros(
1, self.patch_dims[1] * self.patch_dims[2], embed_dim
)
)
self.pos_embed_temporal = nn.Parameter(
torch.zeros(1, self.patch_dims[0], embed_dim)
)
self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
else:
self.pos_embed = nn.Parameter(
torch.zeros(1, pos_embed_dim, embed_dim)
)
self.add = Add()
if self.drop_rate > 0.0:
self.pos_drop = nn.Dropout(p=self.drop_rate)
pool_q = pool_q_kernel
pool_kv = pool_kv_kernel
pool_skip = pool_skip_kernel
stride_q = pool_q_stride
stride_kv = pool_kv_stride
stride_skip = pool_skip_stride
dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
if len(dim_mul_arg) > 1:
for k in dim_mul_arg:
dim_mul[k[0]] = k[1]
if len(head_mul_arg) > 1:
for k in head_mul_arg:
head_mul[k[0]] = k[1]
self.norm_stem = norm_layer(embed_dim) if norm_stem else None
self.blocks = nn.ModuleList()
for i in range(depth):
num_heads = round_width(num_heads, head_mul[i])
embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
dim_out = round_width(
embed_dim,
dim_mul[i + 1],
divisor=round_width(num_heads, head_mul[i + 1]),
)
self.blocks.append(
MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_rate=self.drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
kernel_q=pool_q[i] if len(pool_q) > i else [],
kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
kernel_skip=pool_skip[i] if len(pool_skip) > i else [],
stride_q=stride_q[i] if len(stride_q) > i else [],
stride_kv=stride_kv[i] if len(stride_kv) > i else [],
stride_skip=stride_skip[i] if len(stride_skip) > i else [],
mode=mode,
has_cls_embed=self.cls_embed_on,
)
)
embed_dim = dim_out
self.norm = norm_layer(embed_dim)
self.index_select = IndexSelect()
self.head = TransformerBasicHead(
embed_dim,
num_classes,
dropout_rate=dropout_rate,
act_func=head_act,
)
if self.sep_pos_embed:
trunc_normal_(self.pos_embed_spatial, std=0.02)
trunc_normal_(self.pos_embed_temporal, std=0.02)
trunc_normal_(self.pos_embed_class, std=0.02)
else:
trunc_normal_(self.pos_embed, std=0.02)
if self.cls_embed_on:
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
self.out_thw = None
self.num_frames = num_frames
self.train_crop_size = train_crop_size
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
x = self.patch_embed(x)
t = self.num_frames // self.patch_stride[0]
h = self.train_crop_size // self.patch_stride[1]
w = self.train_crop_size // self.patch_stride[2]
b, _, _ = x.shape
if self.cls_embed_on:
cls_tokens = self.cls_token.expand(
b, -1, -1
)
x = torch.cat((cls_tokens, x), dim=1)
if self.sep_pos_embed:
pos_embed = self.pos_embed_spatial.repeat(
1, self.patch_dims[0], 1
) + torch.repeat_interleave(
self.pos_embed_temporal,
self.patch_dims[1] * self.patch_dims[2],
dim=1,
)
pos_embed_cls = torch.cat([self.pos_embed_class, pos_embed], 1)
x = self.add([x, pos_embed_cls])
else:
x = self.add([x, self.pos_embed])
if self.drop_rate:
x = self.pos_drop(x)
if self.norm_stem:
x = self.norm_stem(x)
thw = [t, h, w]
for blk in self.blocks:
x, thw = blk(x, thw)
self.out_thw = thw
x = self.norm(x)
if self.cls_embed_on:
x = self.index_select(x, dim=1, indices=torch.tensor(0, device=x.device))
x = x.squeeze(1)
else:
x = x.mean(1)
x = self.head(x)
return x
def relprop(self, cam, method="rollout", start_layer=0, **kwargs):
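        """Propagate relevance (cam) back through the network and return a per-token relevance map.

        method:
            "full" - full layer-wise relevance propagation down to the input patches (summed over channels);
            "rollout" - attention rollout over the per-block attention relevance maps;
            "transformer_attribution" / "grad" - gradient-weighted attention relevance rolled out across blocks.
        """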
cam = self.head.relprop(cam, **kwargs)
if self.cls_embed_on:
cam = cam.unsqueeze(1)
cam = self.index_select.relprop(cam, **kwargs)
cam = self.norm.relprop(cam, **kwargs)
thw = self.out_thw
for blk in reversed(self.blocks):
cam, thw = blk.relprop(cam, thw=thw, **kwargs)
if method == "full":
(cam, _) = self.add.relprop(cam, **kwargs)
if self.cls_embed_on:
cam = cam[:, 1:]
cam = self.patch_embed.relprop(cam, **kwargs)
# sum on channels
cam = cam.sum(dim=1)
return cam
elif method == "rollout":
# cam rollout
attn_cams = []
for blk in self.blocks:
attn_heads = blk.attn.get_attn_cam().clamp(min=0)
avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
attn_cams.append(avg_heads)
cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
cam = cam[:, 0, 1:]
return cam
elif method in ("transformer_attribution", "grad"):
cams = []
for blk in self.blocks:
grad = blk.attn.get_attn_gradients()
cam = blk.attn.get_attn_cam()
cam = cam[0].reshape(-1, cam.shape[-2], cam.shape[-1])
grad = grad[0].reshape(-1, grad.shape[-2], grad.shape[-1])
cam = grad * cam
cam = cam.clamp(min=0).mean(dim=0)
cams.append(cam.unsqueeze(0))
rollout = compute_rollout_attention(cams, start_layer=start_layer)
cam = rollout[:, 0, 1:]
return cam