-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcam.py
98 lines (72 loc) · 3.15 KB
/
cam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
import cv2
import numpy as np
class CAM:
    """Base class for CAM-style heatmap visualizers.

    On construction the model is switched to eval mode and moved to
    ``device``, and a forward hook is registered on one layer so that every
    forward pass stores that layer's output in ``self.feature['output']``.
    Subclasses implement :meth:`get_heatmap` on top of that captured tensor.

    Args:
        model: ``torch.nn.Module`` to visualize. It is modified in place:
            every submodule exposing an ``inplace`` attribute gets
            ``inplace = False``, and the model is put in eval mode.
        device: device passed to ``model.to``; kept as ``self.device``.
        preprocess: callable stored as ``self.prep`` for use by subclasses.
        layer_name: name of the layer to hook, as reported by
            ``model.named_modules()``. If ``None``, the module visited
            immediately before the last ``AvgPool2d``/``AdaptiveAvgPool2d``
            is chosen.

    Raises:
        ValueError: if ``layer_name`` is ``None`` and no suitable layer is
            found, or if the model has no module named ``layer_name``.
    """

    def __init__(self, model, device, preprocess, layer_name=None):
        # The hook keeps a reference to the layer's output tensor (not a copy).
        # An in-place activation running later in the forward pass would
        # overwrite it. So disable in-place modules regardless of how the layer
        # was chosen. This only covers modules with an ``inplace`` attribute,
        # not in-place tensor ops written inside a module's ``forward``.
        self._disable_inplace(model)
        if layer_name is None:
            layer_name = self._get_layer_name(model)
        self.layer_name = layer_name
        self.model = model.eval().to(device)
        self.device = device
        self.prep = preprocess
        self.feature = {}
        self._register_hook()

    def get_heatmap(self, img):
        """Return the heatmap for ``img``; overridden by subclasses.

        The base implementation does nothing and returns ``None``.
        """

    @staticmethod
    def _disable_inplace(model):
        """Set ``inplace = False`` on every submodule that has that attribute."""
        for module in model.modules():
            if hasattr(module, 'inplace'):
                module.inplace = False

    def _get_layer_name(self, model):
        """Return the name of the module visited just before the last pooling layer.

        Modules are walked in ``model.named_modules()`` order. Each time an
        ``AdaptiveAvgPool2d``/``AvgPool2d`` is met, the previously visited
        module's name is remembered, so the last pooling layer wins.

        Raises:
            ValueError: if no pooling layer with a predecessor is found.
        """
        layer_name = None
        # Name of the previously visited module. None until one has been seen,
        # so a pooling layer at the root yields the ValueError below rather
        # than an unbound-variable error.
        last_name = None
        for name, module in model.named_modules():
            if isinstance(module, (nn.AdaptiveAvgPool2d, nn.AvgPool2d)):
                layer_name = last_name
            last_name = name
        if layer_name is None:
            raise ValueError('Defaultly use the last layer before global average '
                             'pooling to plot heatmap. However, There is no such '
                             'layer in this model.\n'
                             'So you need to specify the layer to plot heatmap.\n'
                             'Arg "layer_name" is the layer you should specify.\n'
                             'Generally, the layer is deeper, the interpretaton '
                             'is better.')
        return layer_name

    def _forward_hook(self, module, x, y):
        """Forward hook: store the hooked layer's output ``y`` (by reference)."""
        self.feature['output'] = y

    def _register_hook(self):
        """Attach :meth:`_forward_hook` to the module named ``self.layer_name``.

        Raises:
            ValueError: if the model has no module with that name.
        """
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                # Keep the handle so the hook can be detached later with
                # ``self.hook_handle.remove()``.
                self.hook_handle = module.register_forward_hook(self._forward_hook)
                return
        raise ValueError(f'There is no layer named "{self.layer_name}" in the model')

    def _check(self, feature):
        """Raise ``ValueError`` unless ``feature`` is 4-D with more than one spatial cell.

        A feature whose last two dims multiply to 1 (e.g. already globally
        pooled) cannot be turned into a spatial heatmap.
        """
        if feature.ndim != 4 or feature.shape[2] * feature.shape[3] == 1:
            raise ValueError(f'Got invalid shape of feature map: {feature.shape}, '
                             'please specify another layer to plot heatmap.')
class EigenCAM(CAM):
    """CAM variant that derives the heatmap from an SVD of the hooked feature map.

    The constructor is inherited from ``CAM`` unchanged:
    ``EigenCAM(model, device, preprocess, layer_name=None)``.
    """

    def get_heatmap(self, img):
        """Run ``img`` through the model and overlay a JET heatmap on it.

        Args:
            img: input image, passed to ``self.prep`` as-is and then converted
                with ``np.asarray``. Assumed to be a PIL image or an
                (H, W, 3) uint8 array — TODO confirm with callers; other
                channel counts will not broadcast against the 3-channel
                colormap.

        Returns:
            ``(output, overlay)``. ``output`` is the model's raw output.
            ``overlay`` is a uint8 array equal to ``0.6 * img + 0.4 * heatmap``
            at the image's resolution.

        Raises:
            ValueError: if the hooked feature map is not 4-D or is 1x1.
        """
        with torch.no_grad():
            # Add a batch dimension of 1 before the forward pass.
            tensor = self.prep(img)[None, ...].to(self.device)
            output = self.model(tensor)
            feature = self.feature['output']
            self._check(feature)

            # Batched SVD over the last two (spatial) dims of each
            # (batch, channel) slice: vT is (1, C, W, W), and row 0 of each
            # slice is the first right singular vector.
            _, _, vT = torch.linalg.svd(feature)
            v1 = vT[:, :, 0, :][..., None, :]  # (1, C, 1, W)
            # NOTE(review): v1 tiled into a (W, W) matrix whose rows all equal
            # v1 gives, per channel, rowsum(F) outer v1. This is not the rank-1
            # projection F @ v1.T @ v1, and its sign follows the arbitrary SVD
            # sign of v1. Kept as written; confirm intent.
            cam = feature @ v1.repeat(1, 1, v1.shape[3], 1)
            cam = cam.sum(1)  # sum over channels -> (1, H, W)

        # Min-max scale to [0, 255]. Skip the division for a constant map,
        # which would otherwise produce 0/0 = NaN.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak * 255
        cam = cam.cpu().numpy().transpose(1, 2, 0).astype(np.uint8)  # (H, W, 1)

        # Derive the target size from the array shape. ``ndarray.size`` is the
        # element count, so ``img.size`` only worked for PIL input.
        if not isinstance(img, np.ndarray):
            img = np.asarray(img)
        height, width = img.shape[:2]
        cam = cv2.resize(cam, (width, height))  # cv2 dsize is (width, height)
        # NOTE(review): applyColorMap returns BGR, whereas np.asarray of a PIL
        # image is RGB, so heat colors may appear channel-swapped in the blend.
        cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
        overlay = np.uint8(0.6 * img + 0.4 * cam)
        return output, overlay