towhee.trainer.utils.plot_utils.interpret_image_classification

towhee.trainer.utils.plot_utils.interpret_image_classification(model: Module, image: Any, eval_transform: Compose, method: str, fig_size: Tuple = (10, 10), cmap: Any = 'OrRd', pred_label_idx: Optional[int] = None, titles: Optional[List] = None, **kwargs: Any)

Use Captum to interpret the network output for a specified class. Captum must be installed.

Parameters:
  • model (nn.Module) – PyTorch module.

  • image (Any) – The image before the transform is applied. It can be produced by a PyTorch DataLoader or read with PIL's Image.open().

  • eval_transform (transforms.Compose) – Evaluation transform.

  • method (str) – Attribution method; one of ‘Occlusion’, ‘IntegratedGradients’, ‘GradientShap’, or ‘Saliency’.

  • fig_size (Tuple) – Size of the plotted figure.

  • cmap (Any) – Matplotlib colormap.

  • pred_label_idx (int) – Index of the class to interpret. If None, the predicted class is used.

  • titles (List) – Titles of the two axes in the plotted figure.

  • **kwargs (Any) – Additional keyword arguments.
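The accepted values for method can be checked before calling the function. The sketch below is a hypothetical helper (validate_method is not part of towhee) and assumes an exact, case-sensitive match against the four documented names; it does not show how towhee itself validates the argument.

```python
# Hypothetical helper, not a towhee API: the four attribution methods
# listed in the documentation for the `method` parameter.
SUPPORTED_METHODS = ("Occlusion", "IntegratedGradients", "GradientShap", "Saliency")


def validate_method(method: str) -> str:
    """Return `method` unchanged if it is a documented name, else raise ValueError.

    Assumption: names must match exactly (case-sensitive).
    """
    if method not in SUPPORTED_METHODS:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of: "
            + ", ".join(SUPPORTED_METHODS)
        )
    return method
```

For example, validate_method("Saliency") returns "Saliency", while validate_method("GradCAM") raises ValueError.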

Returns:

(tuple)

The prediction score and the label index. If pred_label_idx is specified, the prediction score is None.
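The documented return contract can be sketched in plain Python. resolve_prediction is a hypothetical helper, not a towhee function, and it assumes the prediction score is the softmax probability of the top-scoring class; the towhee source is the authority on how the score is actually computed.

```python
import math
from typing import List, Optional, Tuple


def resolve_prediction(
    logits: List[float], pred_label_idx: Optional[int] = None
) -> Tuple[Optional[float], int]:
    """Hypothetical sketch of the (score, label index) return value.

    If a label index is given, it is returned with a score of None.
    Otherwise the top class is chosen and its softmax probability is
    returned as the score (an assumption about how the score is defined).
    """
    if pred_label_idx is not None:
        return None, pred_label_idx
    # Numerically stable softmax: subtract the max logit before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class (first one wins on ties).
    idx = max(range(len(probs)), key=probs.__getitem__)
    return probs[idx], idx
```

For example, resolve_prediction([0.1, 2.0, 0.5]) selects index 1, while resolve_prediction([0.1, 2.0], pred_label_idx=0) returns (None, 0).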