-
Notifications
You must be signed in to change notification settings - Fork 169
/
main.py
195 lines (164 loc) · 5.99 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# -*- coding: utf-8 -*-
"""
Created on 2019/8/4 上午9:53
@author: mick.yi
入口类
"""
import argparse
import os
import re
import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models
from interpretability.grad_cam import GradCAM, GradCamPlusPlus
from interpretability.guided_back_propagation import GuidedBackPropagation
def get_net(net_name, weight_path=None):
    """
    Build a torchvision classification network by name.

    :param net_name: network name; accepted values are 'vgg'/'vgg16', 'vgg19',
        'resnet'/'resnet50', 'resnet101', 'densenet'/'densenet121', 'inception',
        'mobilenet_v2' and 'shufflenet_v2'
    :param weight_path: path to a saved state_dict; when None, torchvision's
        default pretrained weights are requested instead (pretrained=True)
    :return: the network module
    :raises ValueError: if net_name is not one of the accepted names
    """
    # accepted name -> torchvision.models constructor name
    constructor_names = {
        'vgg': 'vgg16', 'vgg16': 'vgg16',
        'vgg19': 'vgg19',
        'resnet': 'resnet50', 'resnet50': 'resnet50',
        'resnet101': 'resnet101',
        'densenet': 'densenet121', 'densenet121': 'densenet121',
        'inception': 'inception_v3',
        'mobilenet_v2': 'mobilenet_v2',
        'shufflenet_v2': 'shufflenet_v2_x1_0',
    }
    if net_name not in constructor_names:
        raise ValueError('invalid network name:{}'.format(net_name))
    # Only request the default pretrained weights when no weight file was given.
    pretrain = weight_path is None
    net = getattr(models, constructor_names[net_name])(pretrained=pretrain)
    if weight_path is None:
        return net
    # Deserialize onto the CPU: a checkpoint saved from a GPU would otherwise fail to
    # load on a machine without CUDA, and this script never moves anything to a GPU.
    state_dict = torch.load(weight_path, map_location='cpu')
    if net_name.startswith('densenet'):
        # Rename keys such as '...denselayer1.norm.1.weight' to '...denselayer1.norm1.weight'
        # (drop the dot between the sub-layer name and its index) so they match the module.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            match = pattern.match(key)
            if match:
                state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)
    net.load_state_dict(state_dict)
    return net
def get_last_conv_name(net):
    """
    Return the qualified name of the last nn.Conv2d module in the network.

    Modules are visited in `named_modules()` order, so "last" means the last
    Conv2d encountered in that traversal.

    :param net: torch.nn.Module to inspect
    :return: the module name (e.g. 'layer4.2.conv3'), or None if there is no Conv2d
    """
    conv_names = [name for name, module in net.named_modules()
                  if isinstance(module, nn.Conv2d)]
    return conv_names[-1] if conv_names else None
def prepare_input(image):
    """
    Turn an [H,W,3] float image into a normalized [1,3,H,W] tensor that tracks gradients.

    :param image: [H,W,3] float array (main passes RGB scaled to 0~1); not modified
    :return: torch tensor of shape [1,3,H,W] with requires_grad=True
    """
    # Per-channel mean/std (the standard ImageNet normalization values)
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    normalized = image.copy()
    # In-place ops keep the input's dtype (a float32 image stays float32)
    normalized -= channel_mean
    normalized /= channel_std
    # HWC -> CHW, then prepend the batch axis
    batched = np.ascontiguousarray(normalized.transpose(2, 0, 1))[np.newaxis]
    return torch.tensor(batched, requires_grad=True)
def gen_cam(image, mask):
    """
    Overlay a class-activation mask on an image as a JET heatmap.

    :param image: [H,W,C] original image as floats (main passes RGB scaled to 0~1)
    :param mask: [H,W] activation mask, expected in 0~1; NaN becomes 0 and
        out-of-range values are clipped before colour mapping
    :return: tuple(cam, heatmap): cam is the min-max normalized uint8 overlay,
        heatmap is the uint8 RGB colormap of the mask
    """
    # Casting out-of-range or NaN floats to uint8 is undefined in numpy, so sanitize
    # first (a NaN mask can arise if upstream normalization divides by zero).
    safe_mask = np.clip(np.nan_to_num(mask), 0, 1)
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * safe_mask), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; flip the channel axis to get RGB.
    heatmap_rgb = np.ascontiguousarray(heatmap_bgr[..., ::-1])
    # Blend the heatmap (scaled to 0~1) with the original image.
    cam = np.float32(heatmap_rgb) / 255 + np.float32(image)
    # Return the exact uint8 colormap rather than a lossy float round trip.
    return norm_image(cam), heatmap_rgb
def norm_image(image):
    """
    Min-max rescale an image to the 0~255 uint8 range.

    :param image: [H,W,C] float array; not modified
    :return: uint8 array of the same shape; all zeros when the image is constant
    """
    image = image.copy()
    # Shift so the minimum becomes 0.
    image -= np.min(image)
    peak = np.max(image)
    # A constant image has zero range: dividing would give NaN, whose uint8 cast is undefined.
    if peak > 0:
        image /= peak
    image *= 255.
    return np.uint8(image)
def gen_gb(grad):
    """
    Convert a guided back-propagation input gradient to channel-last image layout.

    :param grad: tensor, [3,H,W] gradient with respect to the input image
    :return: numpy array [H,W,3]; values are left unnormalized (see norm_image)
    """
    # detach() drops autograd tracking; cpu() lets CUDA tensors work too,
    # since numpy() only accepts CPU tensors.
    return np.transpose(grad.detach().cpu().numpy(), (1, 2, 0))
def save_image(image_dicts, input_image_name, network, output_dir):
    """
    Save every image as '<prefix>-<network>-<key>.jpg' inside output_dir.

    :param image_dicts: dict mapping a result name (e.g. 'cam', 'gb') to an image array
    :param input_image_name: input image file name; its extension is stripped to form the prefix
    :param network: network name, embedded in each output file name
    :param output_dir: destination directory; created if it does not exist yet
    """
    # The command-line default ('results') is a relative directory that may not exist;
    # create it rather than failing on the first write. '' means the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    prefix = os.path.splitext(input_image_name)[0]
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)
def _ensure_rgb(img):
    """
    Return img as an [H,W,3] array with the same dtype.

    Grayscale ([H,W] or [H,W,1]) is replicated into 3 channels; a trailing alpha
    channel (gray+alpha or RGBA) is dropped.

    :param img: image array as returned by io.imread
    :return: C-contiguous [H,W,3] array
    """
    if img.ndim == 3 and img.shape[-1] in (1, 2):
        img = img[..., 0]  # single channel, or gray + alpha -> gray
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]  # RGBA -> RGB
    return np.ascontiguousarray(img)


def main(args):
    """
    Run Grad-CAM, Grad-CAM++, guided back-propagation and guided Grad-CAM on one
    image and save every result image.

    :param args: parsed command-line namespace with network, image_path, weight_path,
        layer_name, class_id and output_dir
    """
    # Input: prepare_input's normalization and gen_cam's overlay both need exactly 3 channels.
    img = _ensure_rgb(io.imread(args.image_path))
    img = np.float32(cv2.resize(img, (224, 224))) / 255
    inputs = prepare_input(img)
    # Output images, keyed by result name
    image_dict = {}
    # Network
    net = get_net(args.network, args.weight_path)
    # Grad-CAM
    layer_name = get_last_conv_name(net) if args.layer_name is None else args.layer_name
    grad_cam = GradCAM(net, layer_name)
    mask = grad_cam(inputs, args.class_id)  # cam mask
    image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
    grad_cam.remove_handlers()
    # Grad-CAM++
    grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
    mask_plus_plus = grad_cam_plus_plus(inputs, args.class_id)  # cam mask
    image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
    grad_cam_plus_plus.remove_handlers()
    # GuidedBackPropagation
    gbp = GuidedBackPropagation(net)
    # Clear any gradient the CAM passes above left on the input; it may be None
    # if nothing back-propagated into the input tensor.
    if inputs.grad is not None:
        inputs.grad.zero_()
    grad = gbp(inputs)
    gb = gen_gb(grad)
    image_dict['gb'] = norm_image(gb)
    # Guided Grad-CAM: guided gradients weighted by the Grad-CAM mask
    cam_gb = gb * mask[..., np.newaxis]
    image_dict['cam_gb'] = norm_image(cam_gb)
    save_image(image_dict, os.path.basename(args.image_path), args.network, args.output_dir)
if __name__ == '__main__':
    # Command-line options: (flag, type, default, help text)
    cli_options = (
        ('--network', str, 'resnet50', 'ImageNet classification network'),
        ('--image-path', str, './examples/pic1.jpg', 'input image path'),
        ('--weight-path', str, None, 'weight path of the model'),
        ('--layer-name', str, None, 'last convolutional layer name'),
        ('--class-id', int, None, 'class id'),
        ('--output-dir', str, 'results', 'output directory to save results'),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    main(parser.parse_args())