# Original PyTorch implementation from:
# 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
# - https://arxiv.org/abs/2010.11929
# 'How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers'
# - https://arxiv.org/abs/2106.10270
#
# Original code by / Copyright 2020, Ross Wightman.
# Modifications & additions by / Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from functools import partial

import torch
from torch import nn
from torch.utils import model_zoo

from towhee.models.utils.init_vit_weights import init_vit_weights
from towhee.models.layers.patch_embed2d import PatchEmbed2D
from towhee.models.layers.layers_with_relprop import LayerNorm, GELU, Linear, IndexSelect, Add
from .vit_utils import get_configs
from .vit_block import Block


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Roll out attention across layers: add the identity to each layer's
    attention map to account for residual connections, then chain the maps
    by matrix multiplication from ``start_layer`` to the last layer."""
    # Add the identity so residual (skip) connections contribute to the rollout.
num_tokens = all_layer_matrices[0].shape[1]
batch_size = all_layer_matrices[0].shape[0]
eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
# all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
# for i in range(len(all_layer_matrices))]
joint_attention = all_layer_matrices[start_layer]
for i in range(start_layer + 1, len(all_layer_matrices)):
joint_attention = all_layer_matrices[i].bmm(joint_attention)
return joint_attention
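
# Hedged usage sketch: given one head-averaged attention map per layer, each of
# shape (batch, tokens, tokens), the rollout chains them into a single token-to-
# token relevance map. Shapes below assume ViT-B/16 at 224x224 (196 patches +
# 1 CLS token = 197 tokens); the random maps are illustrative only:
#
#     attn_maps = [torch.softmax(torch.randn(2, 197, 197), dim=-1)
#                  for _ in range(12)]
#     rollout = compute_rollout_attention(attn_maps)  # (2, 197, 197)
#     cls_to_patches = rollout[:, 0, 1:]              # CLS-row relevance, (2, 196)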


class VitModel(nn.Module):
"""
Vision Transformer Model
Args:
img_size (int): image height or width (height=width)
patch_size (int): patch height or width (height=width)
in_c (int): number of image channels
num_classes (int): number of classes
embed_dim (int): number of features
depth (int): number of blocks
num_heads (int): number of heads for Multi-Attention layer
mlp_ratio (float): mlp ratio
qkv_bias (bool): if add bias to qkv layer
qk_scale (float): number to scale qk
representation_size (int): size of representations
drop_ratio (float): drop rate of a block
attn_drop_ratio (float): drop rate of attention layer
drop_path_ratio (float): drop rate of drop_path layer
embed_layer: patch embedding layer
norm_layer: normalization layer
act_layer: activation layer
"""
    def __init__(self,
img_size=224, patch_size=16, in_c=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
representation_size=None, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
embed_layer=PatchEmbed2D, norm_layer=None, act_layer=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
norm_layer = norm_layer or partial(LayerNorm, eps=1e-6)
act_layer = act_layer or GELU
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_c, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
("fc", Linear(embed_dim, representation_size)), # pylint: disable=too-many-function-args
("act", nn.Tanh())
]))
else:
self.has_logits = False
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # pylint: disable=too-many-function-args
# Weight init
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(init_vit_weights)
self.pool = IndexSelect()
self.add = Add()
self.inp_grad = None
    def forward_features(self, x):
        """Return the pre-logits features of the CLS token, shape (B, num_features)."""
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # e.g. [B, 196, 768] for ViT-B/16 at 224x224
# [1, 1, 768] -> [B, 1, 768]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
x = self.add([x, self.pos_embed])
        # Cache the gradient w.r.t. the embedded tokens for attribution methods;
        # registering a hook on a tensor that does not require grad would raise.
        if x.requires_grad:
            x.register_hook(self.save_inp_grad)
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return self.pre_logits(x[:, 0])
    def forward(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # e.g. [B, 196, 768] for ViT-B/16 at 224x224
# [1, 1, 768] -> [B, 1, 768]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
x = self.add([x, self.pos_embed])
        if x.requires_grad:  # hook registration fails when grad is disabled
            x.register_hook(self.save_inp_grad)
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
x = x.squeeze(1)
x = self.head(x)
return x
def save_inp_grad(self, grad):
self.inp_grad = grad
def get_inp_grad(self):
return self.inp_grad
    def relprop(self, cam=None, method="transformer_attribution", start_layer=0, **kwargs):
        """Propagate relevance ``cam`` from the classifier head back through the network."""
cam = self.head.relprop(cam, **kwargs)
cam = cam.unsqueeze(1)
cam = self.pool.relprop(cam, **kwargs)
cam = self.norm.relprop(cam, **kwargs)
for blk in reversed(self.blocks):
cam = blk.relprop(cam, **kwargs)
if method == "full":
(cam, _) = self.add.relprop(cam, **kwargs)
cam = cam[:, 1:]
cam = self.patch_embed.relprop(cam, **kwargs)
# sum on channels
cam = cam.sum(dim=1)
return cam
elif method == "rollout":
# cam rollout
attn_cams = []
for blk in self.blocks:
attn_heads = blk.attn.get_attn_cam().clamp(min=0)
avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
attn_cams.append(avg_heads)
cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
cam = cam[:, 0, 1:]
return cam
        # Transformer attribution; the method name "grad" is kept as a legacy alias.
        elif method in ("transformer_attribution", "grad"):
cams = []
for blk in self.blocks:
grad = blk.attn.get_attn_gradients()
cam = blk.attn.get_attn_cam()
cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
cam = grad * cam
cam = cam.clamp(min=0).mean(dim=0)
cams.append(cam.unsqueeze(0))
            rollout = compute_rollout_attention(cams, start_layer=start_layer)
            cam = rollout[:, 0, 1:]
            return cam
        else:
            raise ValueError(f"Unsupported relprop method: {method}")
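
# Hedged usage sketch for VitModel.relprop (an assumption about the intended
# recipe, following the standard transformer-attribution workflow): relevance
# propagation needs a backward pass first so the attention layers can cache
# their gradients. Depending on the relprop layer implementations, extra
# kwargs (e.g. alpha) may be required:
#
#     model = VitModel()
#     logits = model(torch.randn(1, 3, 224, 224))
#     one_hot = torch.zeros_like(logits)
#     one_hot[0, logits.argmax(dim=-1)] = 1.0
#     (logits * one_hot).sum().backward(retain_graph=True)
#     cam = model.relprop(one_hot, method="transformer_attribution")  # (1, 196)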


def create_model(
model_name: str = None,
pretrained: bool = False,
weights_path: str = None,
device: str = None,
**kwargs
):
    """Build a ``VitModel``, optionally loading pretrained weights from a URL or local path."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model_name is None:
        if pretrained:
            raise AssertionError("Fail to load pretrained model: no model name is specified.")
        model = VitModel(**kwargs)
    else:
        configs = get_configs(model_name)
        # Pop the weights URL (if any) so the remaining configs map to VitModel kwargs.
        url = configs.pop("url", None)
        model = VitModel(**configs)
        if pretrained:
            if weights_path:
                state_dict = torch.load(weights_path, map_location=torch.device(device))
            elif url:
                state_dict = model_zoo.load_url(url, map_location=torch.device(device))
            else:
                raise AssertionError("No model weights url or path is provided.")
            model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model
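

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the module's public API): build
    # a randomly initialized model directly from kwargs and run a forward pass.
    # device="cpu" keeps the model and the dummy input on the same device.
    vit = create_model(num_classes=10, device="cpu")
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = vit(dummy)
    print(logits.shape)  # expected: torch.Size([1, 10])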