# Built on top of the original implementation at https://github.com/hila-chefer/Transformer-MM-Explainability
#
# Modifications by Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Callable
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch import nn
from towhee.models.visualization.transformer_visualization import _reshape_attr_and_get_heatmap
from towhee.models import clip
from towhee.models.clip import SimpleTokenizer
from towhee.trainer.utils.file_utils import is_captum_available, is_matplotlib_available
from towhee.utils.log import models_log
import torch
import numpy as np


def get_clip_relevance(model: nn.Module, pil_img: Image, text_list: List[str], device: str, vis_start_layer: int = 11,
                       text_start_layer: int = 11, transform: Callable = None, tokenize: Callable = None) -> Tuple:
"""
Get text relevance and image relevance from CLIP model.
Args:
model (`nn.Module`):
CLIP model to visualize.
pil_img (`Image`):
Input image.
text_list (`List[str]`):
List of text str.
device (`str`):
Device to use.
vis_start_layer (`int`):
Start layer for visualization.
text_start_layer (`int`):
Start layer for text.
transform (`Callable`):
Transform function for image.
tokenize (`Callable`):
Tokenize function for text.
Returns:
(`Tuple`):
text_relevance, image_relevance, text_tokens, img_tensor
"""
    if transform is None:
        def _transform(n_px):
            return Compose([
                Resize(n_px, interpolation=Image.BICUBIC),
                CenterCrop(n_px),
                lambda image: image.convert('RGB'),
                ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ])

        transform = _transform(224)
    if tokenize is None:
        tokenize = clip.tokenize
    img_tensor = transform(pil_img).unsqueeze(0).to(device)
    text_tokens = tokenize(text_list).to(device)
    batch_size = text_tokens.shape[0]
    img_tensors = img_tensor.repeat(batch_size, 1, 1, 1)
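    # The single image is repeated so every text prompt is scored against it; a one-hot mask
    # over the diagonal of ``logits_per_image`` selects logit(image, text_i) for each pair.
    # Summing the masked logits yields a scalar whose gradients w.r.t. each block's attention
    # probabilities drive the relevance maps computed below.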
    logits_per_image, _ = model(img_tensors, text_tokens, device=device)
    index = list(range(batch_size))
    one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
    one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.to(device) * logits_per_image)
    model.zero_grad()
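    # Gradient-weighted attention rollout over the visual transformer (Chefer et al.):
    # starting from an identity relevance matrix, every block from ``vis_start_layer`` onward
    # multiplies its attention probabilities by their gradient w.r.t. the selected logits,
    # clamps negatives, averages over heads, and accumulates the result into the running
    # relevance with a batched matrix product.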
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    relevance = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    relevance = relevance.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < vis_start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        relevance = relevance + torch.bmm(cam, relevance)
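    # Row 0 of the rolled-out relevance belongs to the CLS token; dropping column 0
    # (its self-relevance) leaves one score per image patch token.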
    image_relevance = relevance[:, 0, 1:]

    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
    num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
    rel_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
    rel_text = rel_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
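    # Same gradient-weighted rollout, this time over the text transformer blocks
    # starting at ``text_start_layer``.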
    for i, blk in enumerate(text_attn_blocks):
        if i < text_start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        rel_text = rel_text + torch.bmm(cam, rel_text)
    text_relevance = rel_text
    return text_relevance, image_relevance, text_tokens, img_tensor


def show_image_relevance(image_relevance: torch.Tensor, img_tensor: torch.Tensor, orig_image: Image):
"""
Show the image relevance heatmap.
Args:
image_relevance (`torch.Tensor`):
Image relevance.
img_tensor (`torch.Tensor`):
Transformed image tensor.
orig_image (`Image`):
Original input image.
"""
    if not is_matplotlib_available():
        models_log.warning('Matplotlib is not available.')
        return None
    import matplotlib.pylab as plt # pylint: disable=import-outside-toplevel

    _, axs = plt.subplots(1, 2)
    axs[0].imshow(orig_image)
    axs[0].axis('off')
    vis = _reshape_attr_and_get_heatmap(image_relevance, img_tensor)
    axs[1].imshow(vis)
    axs[1].axis('off')
    plt.show()


def show_heatmap_on_text(text: str, text_encoding: torch.Tensor, rel_text: torch.Tensor):
"""
Show the text relevance heatmap.
Args:
text (`str`):
Text to show.
text_encoding (`torch.Tensor`):
Tokenized text.
rel_text (`torch.Tensor`):
Text relevance.
"""
    if not is_captum_available():
        models_log.warning('You should install Captum first. Please run `pip install captum`.')
        return None
    from captum.attr import visualization as viz # pylint: disable=import-outside-toplevel
    if not is_matplotlib_available():
        models_log.warning('Matplotlib is not available.')
        return None
    import matplotlib.pylab as plt # pylint: disable=import-outside-toplevel

    tokenizer = SimpleTokenizer()
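    # The highest token id in a CLIP encoding is the end-of-text (EOT) token, which CLIP uses
    # as the sequence representation. Its relevance row, restricted to the tokens between
    # start-of-text and EOT, gives one score per word token.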
    cls_idx = text_encoding.argmax(dim=-1)
    rel_text = rel_text[cls_idx, 1:cls_idx]
    text_scores = rel_text / rel_text.sum()
    text_scores = text_scores.flatten()
    text_tokens = tokenizer.encode(text)
    text_tokens_decoded = [tokenizer.decode([a]) for a in text_tokens]
    vis_data_records = [viz.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, text_tokens_decoded, 1)]
    viz.visualize_text(vis_data_records)
    plt.show()


def show_attention_for_clip(model: nn.Module, pil_img: Image, text_list: List[str], device: str = None):
"""
Show the attention for CLIP model. This function can show one image with multiple texts.
Args:
model (`nn.Module`):
CLIP model to show attention.
pil_img (`Image`):
Image to show.
text_list (`List[str]`):
Text list to show.
device (`str`):
Device to use.
"""
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rel_text, rel_image, text_tokens, img_tensor = get_clip_relevance(model, pil_img, text_list, device)
    if not is_matplotlib_available():
        models_log.warning('Matplotlib is not available.')
        return None
    import matplotlib.pylab as plt # pylint: disable=import-outside-toplevel

    batch_size = len(text_list)
    for i in range(batch_size):
        show_heatmap_on_text(text_list[i], text_tokens[i], rel_text[i])
        show_image_relevance(rel_image[i], img_tensor, orig_image=pil_img)
        plt.show()
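

# Usage sketch (hypothetical setup; the exact factory call and the option that makes the
# attention blocks record ``attn_probs`` may differ in your towhee version):
#
#     from PIL import Image
#     from towhee.models import clip
#
#     model = clip.create_model(...)  # a CLIP model whose attention blocks expose `attn_probs`
#     image = Image.open('example.jpg')
#     show_attention_for_clip(model, image, ['a photo of a cat', 'a photo of a dog'])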