towhee.trainer.utils.plot_utils.interpret_image_classification
- towhee.trainer.utils.plot_utils.interpret_image_classification(model: Module, image: Any, eval_transform: Compose, method: str, fig_size: Tuple = (10, 10), cmap: Any = 'OrRd', pred_label_idx: Optional[int] = None, titles: Optional[List] = None, **kwargs: Any)
Use Captum to interpret the specified class of the network output. Captum must be installed.
- Parameters:
model (nn.Module) – PyTorch module.
image (Any) – The image before the transform is applied. It can be produced by a PyTorch DataLoader or read with Image.open() from PIL.
eval_transform (transforms.Compose) – Evaluation transform.
method (str) – Attribution method. One of ‘Occlusion’, ‘IntegratedGradients’, ‘GradientShap’, ‘Saliency’.
fig_size (Tuple) – Figure plotting size.
cmap (Any) – Matplotlib colormap.
pred_label_idx (int, optional) – Index of the class to interpret. If None, the predicted class is used automatically.
titles (List) – Titles of the two axes in the figure.
**kwargs (Any) – Additional keyword arguments.
- Returns:
- (tuple)
Prediction score and label index. If pred_label_idx is specified, the prediction score is None.
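A call to this function might look like the sketch below. It is illustrative only: `check_method`, `interpret`, and `demo` are hypothetical helpers that are not part of towhee, the path `cat.jpg` is a placeholder, and the ResNet-18 model with ImageNet normalization is an assumed setup. Whether towhee itself validates the `method` string is not stated above, so the sketch checks it against the documented names before calling.

```python
from typing import Any, Optional, Tuple

# Method names listed in the parameter description above.
SUPPORTED_METHODS = ('Occlusion', 'IntegratedGradients', 'GradientShap', 'Saliency')


def check_method(method: str) -> str:
    """Hypothetical helper: return `method` if it is a documented name."""
    if method not in SUPPORTED_METHODS:
        raise ValueError(f'method must be one of {SUPPORTED_METHODS}, got {method!r}')
    return method


def interpret(model: Any, image: Any, eval_transform: Any, method: str,
              pred_label_idx: Optional[int] = None) -> Tuple[Any, Any]:
    """Hypothetical wrapper: validate `method`, then call the documented function."""
    # Imported lazily so this sketch can be loaded without towhee or Captum.
    from towhee.trainer.utils.plot_utils import interpret_image_classification
    return interpret_image_classification(
        model, image, eval_transform, check_method(method),
        pred_label_idx=pred_label_idx,
    )


def demo(image_path: str = 'cat.jpg') -> None:
    """Assumed setup: a pretrained torchvision ResNet-18 and a PIL image."""
    from PIL import Image
    from torchvision import models, transforms

    # weights='DEFAULT' needs torchvision >= 0.13; older versions use pretrained=True.
    model = models.resnet18(weights='DEFAULT').eval()
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert('RGB')  # image before the transform
    score, label_idx = interpret(model, image, eval_transform, 'Occlusion')
    # Per the Returns section, score is None when pred_label_idx is given.
    print(score, label_idx)
```

Calling `demo()` requires towhee, Captum, torchvision, Pillow, and an image file at the given path. Passing `pred_label_idx` to `interpret` selects a specific class to explain instead of the predicted one.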