# Original pytorch implementation by:
# 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
#       - https://arxiv.org/abs/2010.11929
# 'How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers'
#       - https://arxiv.org/abs/2106.10270
# Original code by / Copyright 2020, Ross Wightman.
# Modifications & additions by / Copyright 2021 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torch.utils import model_zoo
from functools import partial
from collections import OrderedDict

from towhee.models.utils.init_vit_weights import init_vit_weights
from towhee.models.layers.patch_embed2d import PatchEmbed2D
from .vit_utils import get_configs
from .vit_block import Block
from towhee.models.layers.layers_with_relprop import LayerNorm, GELU, Linear, IndexSelect, Add

def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """
    Roll per-layer attention matrices out into one joint attention matrix.

    An identity matrix is added to every layer matrix to account for the
    residual connection, then the matrices from ``start_layer`` onwards are
    chained with batched matrix multiplication (later layers on the left).

    Args:
        all_layer_matrices (`list` of `torch.Tensor`):
            One tensor per layer, each of shape (batch, tokens, tokens).
        start_layer (`int`):
            Index of the first layer included in the rollout.

    Returns:
        `torch.Tensor` of shape (batch, tokens, tokens).
    """
    first = all_layer_matrices[0]
    num_tokens = first.shape[1]

    # Residual consideration: the identity is built directly on the target
    # device and broadcast over the batch dimension by the addition below.
    eye = torch.eye(num_tokens, device=first.device)
    with_residual = [layer_matrix + eye for layer_matrix in all_layer_matrices]
    # NOTE: the matrices are not row-normalised after adding the identity
    # (the normalisation step is disabled in the original implementation).

    joint_attention = with_residual[start_layer]
    for layer_matrix in with_residual[start_layer + 1:]:
        joint_attention = layer_matrix.bmm(joint_attention)
    return joint_attention
class VitModel(nn.Module):
    """
    Vision Transformer Model

    The model is assembled from relevance-propagation aware layers, so that
    `relprop` can be used after a forward/backward pass to compute token
    relevance maps.

    Args:
        img_size (`int`):
            Image height or width (height=width).
        patch_size (`int`):
            Patch height or width (height=width).
        in_c (`int`):
            Number of image channels.
        num_classes (`int`):
            Number of classes.
        embed_dim (`int`):
            Number of features.
        depth (`int`):
            Number of blocks.
        num_heads (`int`):
            Number of heads for Multi-Attention layer.
        mlp_ratio (`float`):
            Mlp ratio.
        qkv_bias (`bool`):
            If add bias to qkv layer.
        qk_scale (`float`):
            Number to scale qk
        representation_size (`int`):
            Size of representations.
        drop_ratio (`float`):
            Drop rate of a block
        attn_drop_ratio (`float`):
            Drop rate of attention layer
        drop_path_ratio (`float`):
            Drop rate of drop_path layer
        embed_layer (`nn.Module`):
            Patch embedding layer.
        norm_layer (`nn.Module`):
            Normalization layer.
        act_layer(`nn.Module`):
            Activation layer.
    """

    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, representation_size=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed2D,
                 norm_layer=None, act_layer=None):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1  # only the class token is prepended to the patch tokens
        norm_layer = norm_layer or partial(LayerNorm, eps=1e-6)
        act_layer = act_layer or GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", Linear(embed_dim, representation_size)),  # pylint: disable=too-many-function-args
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # pylint: disable=too-many-function-args

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(init_vit_weights)

        # Helper layers used by forward() and relprop(); created after the
        # weight initialisation above, exactly as in the original ordering.
        self.pool = IndexSelect()
        self.add = Add()

        # Gradient w.r.t. the position-embedded input, filled by a backward hook.
        self.inp_grad = None

    def _encode(self, x):
        """Embed an image batch and run it through all blocks and the final norm."""
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768] with the default configuration
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        x = self.add([x, self.pos_embed])

        # A hook can only be registered on a tensor that requires grad; skip it
        # when running without autograd (e.g. under torch.no_grad()).
        if x.requires_grad:
            x.register_hook(self.save_inp_grad)

        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

    def forward_features(self, x):
        """
        Return the class-token representation after the `pre_logits` layer.

        Args:
            x (`torch.Tensor`):
                Input image batch, [B, C, H, W].
        """
        x = self._encode(x)
        return self.pre_logits(x[:, 0])

    def forward(self, x):
        """
        Run the model and return the classifier head output for the class token.

        Args:
            x (`torch.Tensor`):
                Input image batch, [B, C, H, W].
        """
        x = self._encode(x)
        # Select the class token (index 0 along the token axis) through the
        # IndexSelect layer so that relprop() can propagate back through it.
        x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
        x = x.squeeze(1)
        # NOTE(review): unlike forward_features, pre_logits is not applied here
        # while self.head is sized from self.num_features; with a non-None
        # representation_size the shapes look mismatched -- confirm.
        x = self.head(x)
        return x

    def save_inp_grad(self, grad):
        """Backward hook: store the gradient of the position-embedded input."""
        self.inp_grad = grad

    def get_inp_grad(self):
        """Return the gradient stored by `save_inp_grad` (None if never set)."""
        return self.inp_grad

    def relprop(self, cam: torch.Tensor = None, method: str = "transformer_attribution",
                start_layer: int = 0, **kwargs):
        """
        Propagate relevance from the output back through the model.

        Args:
            cam (`torch.Tensor`):
                Relevance at the model output, passed to `self.head.relprop`.
            method (`str`):
                One of "full", "rollout", "transformer_attribution" or
                "grad" (legacy alias of "transformer_attribution").
            start_layer (`int`):
                First block included in the rollout (ignored by "full").
            **kwargs:
                Forwarded to every layer's `relprop`.

        Returns:
            `torch.Tensor` with the relevance map produced by `method`.

        Raises:
            ValueError: If `method` is not one of the supported names.
        """
        supported = ("full", "rollout", "transformer_attribution", "grad")
        if method not in supported:
            # Previously an unknown method silently fell through and returned None.
            raise ValueError(f"Unsupported relprop method {method!r}; expected one of {supported}.")

        # Relevance is always propagated through head, pooling, norm and blocks
        # first; the blocks record what the branches below read back.
        cam = self.head.relprop(cam, **kwargs)
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        cam = self.norm.relprop(cam, **kwargs)
        for blk in reversed(self.blocks):
            cam = blk.relprop(cam, **kwargs)

        if method == "full":
            (cam, _) = self.add.relprop(cam, **kwargs)
            cam = cam[:, 1:]  # drop the class token
            cam = self.patch_embed.relprop(cam, **kwargs)
            # sum on channels
            cam = cam.sum(dim=1)
            return cam

        if method == "rollout":
            # cam rollout
            attn_cams = []
            for blk in self.blocks:
                attn_heads = blk.attn.get_attn_cam().clamp(min=0)
                avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
                attn_cams.append(avg_heads)
            cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
            cam = cam[:, 0, 1:]  # class-token row, without the class-token column
            return cam

        # our method, method name grad is legacy
        cams = []
        for blk in self.blocks:
            grad = blk.attn.get_attn_gradients()
            cam = blk.attn.get_attn_cam()
            # Only the first sample of the batch is used.
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cams.append(cam.unsqueeze(0))
        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        cam = rollout[:, 0, 1:]
        return cam
def create_model(
        model_name: str = None,
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None,
        **kwargs
) -> "VitModel":
    """
    Create ViT model.

    Args:
        model_name (`str`):
            ViT model name. When None, the model is built from `kwargs`.
        pretrained (`bool`):
            Whether it is a pretrained model.
        weights_path (`str`):
            If pretrained, you can use your local weights.
        device (`str`):
            Cuda or cpu. Used as the map location when loading weights;
            defaults to "cuda" when available, otherwise "cpu".
        **kwargs:
            Passed to `VitModel` only when `model_name` is None.

    Returns:
        VitModel, switched to evaluation mode.

    Raises:
        AssertionError: If `pretrained` is set without a model name, or if
            neither a weights path nor a weights url is available.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_name is None:
        if pretrained:
            raise AssertionError("Fail to load pretrained model: no model name is specified.")
        model = VitModel(**kwargs)
    else:
        configs = get_configs(model_name)
        # A default keeps `url` defined for configurations without a weights url.
        url = configs.pop("url", None)
        model = VitModel(**configs)
        if pretrained:
            map_location = torch.device(device)
            if weights_path:
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                # map_location lets CUDA-saved weights load on CPU-only hosts.
                state_dict = torch.load(weights_path, map_location=map_location)
            elif url:
                state_dict = model_zoo.load_url(url, map_location=map_location)
            else:
                raise AssertionError("No model weights url or path is provided.")
            model.load_state_dict(state_dict, strict=False)

    model.eval()
    return model