get_cams.py (forked from marshuang80/penet)
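"""Generate Grad-CAM visualizations for a trained model checkpoint.

Loads a saved model, runs Grad-CAM (and, optionally, guided backpropagation) on
abnormal CT studies from the data loader, and writes the input volume, the CAM
overlay, and a side-by-side combination of the two to GIF files in --cam_dir.
"""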
import moviepy.editor as mpy
import numpy as np
import os
import torch
import util
from args import TestArgParser
from cams import GradCAM
from cams import GuidedBackPropagation
from data_loader import CTDataLoader
from saver import ModelSaver
from collections import defaultdict


def get_cams(args):
    print('Loading model...')
    model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
    model = model.to(args.device)
    args.start_epoch = ckpt_info['epoch'] + 1

    print('Last layer in model.module.encoders is named "{}"...'.format(list(model.module.encoders._modules.keys())[-1]))
    print('Extracting feature maps from layer named "{}"...'.format(args.target_layer))

    grad_cam = GradCAM(model, args.device, is_binary=True, is_3d=True)
    gbp = GuidedBackPropagation(model, args.device, is_binary=True, is_3d=True)

    num_generated = 0
    data_loader = CTDataLoader(args, phase=args.phase, is_training=False)
    study_idx_dict = defaultdict(int)
    for inputs, target_dict in data_loader:
        labels = target_dict['is_abnormal']
        if labels.item() == 0:
            # Keep going until we get an abnormal study
            print('Skipping a normal example...')
            continue

        print('Generating CAM...')
        with torch.set_grad_enabled(True):
            probs, idx = grad_cam.forward(inputs)
            grad_cam.backward(idx=idx[0])  # Just take top prediction
            cam = grad_cam.get_cam(args.target_layer)

            guided_backprop = None
            if args.use_gbp:
                inputs2 = torch.autograd.Variable(inputs, requires_grad=True)
                probs2, idx2 = gbp.forward(inputs2)
                gbp.backward(idx=idx2[0])
                guided_backprop = np.squeeze(gbp.generate())

        print('Overlaying CAM...')
        print(cam.shape)
        new_cam = util.resize(cam, inputs[0])
        print(new_cam.shape)

        # Un-normalize the input volume and overlay the CAM as a heat map, slice by slice.
        input_np = util.un_normalize(inputs[0], args.img_format, data_loader.dataset.pixel_dict)
        input_np = np.transpose(input_np, (1, 2, 3, 0))
        input_frames = list(input_np)
        input_normed = np.float32(input_np) / 255
        cam_frames = list(util.add_heat_map(input_normed, new_cam))

        gbp_frames = None
        if args.use_gbp:
            gbp_np = util.normalize_to_image(guided_backprop * new_cam)
            gbp_frames = []
            for dim in range(gbp_np.shape[0]):
                slice_ = gbp_np[dim, :, :]
                gbp_frames.append(slice_[..., None])

        # Identify the study for the output file names. The 'study_num' key is assumed
        # to be present in target_dict; fall back to the running count if it is not.
        study_num = int(target_dict['study_num']) if 'study_num' in target_dict else num_generated + 1
        study_idx_dict[study_num] += 1
        study_count = study_idx_dict[study_num]

        # Write the input, CAM, and combined clips to GIF files
        output_path_input = os.path.join(args.cam_dir, '{}_{}_input_fn_intermountain.gif'.format(study_num, study_count))
        output_path_cam = os.path.join(args.cam_dir, '{}_{}_cam_fn_intermountain.gif'.format(study_num, study_count))
        output_path_combined = os.path.join(args.cam_dir, '{}_{}_combined_fn_intermountain.gif'.format(study_num, study_count))

        print('Writing set {}/{} of CAMs to {}...'.format(num_generated + 1, args.num_cams, args.cam_dir))
        input_clip = mpy.ImageSequenceClip(input_frames, fps=4)
        input_clip.write_gif(output_path_input, verbose=False)
        cam_clip = mpy.ImageSequenceClip(cam_frames, fps=4)
        cam_clip.write_gif(output_path_cam, verbose=False)
        combined_clip = mpy.clips_array([[input_clip, cam_clip]])
        combined_clip.write_gif(output_path_combined, verbose=False)

        if args.use_gbp:
            output_path_gbp = os.path.join(args.cam_dir, 'gbp_{}.gif'.format(num_generated + 1))
            gbp_clip = mpy.ImageSequenceClip(gbp_frames, fps=4)
            gbp_clip.write_gif(output_path_gbp, verbose=False)

        num_generated += 1
        if num_generated == args.num_cams:
            return


if __name__ == '__main__':
    parser = TestArgParser()
    parser.parser.add_argument('--test_2d', action='store_true', help='Test CAMs on a pretrained 2D VGG net.')
    parser.parser.add_argument('--target_layer', type=str, default='module.encoders.3',
                               help='Name of target layer for extracting feature maps.')
    parser.parser.add_argument('--cam_dir', type=str, default='data/', help='Directory to write CAM outputs.')
    parser.parser.add_argument('--num_cams', type=int, default=1, help='Number of CAMs to generate.')
    parser.parser.add_argument('--use_gbp', type=util.str_to_bool, default=False,
                               help='If True, use guided backprop. Else just regular CAMs.')

    # Hard-coded settings (tested for PE dataset only)
    args_ = parser.parse_args()
    args_.do_hflip = False
    args_.do_vflip = False
    args_.do_center_pe = True
    args_.do_jitter = False
    args_.do_rotate = False
    args_.use_hem = False
    args_.abnormal_prob = 1.

    get_cams(args_)
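
# Example invocation (a sketch, not from the original repository; --ckpt_path,
# --gpu_ids, and --phase are assumed to be defined by TestArgParser, matching the
# args.ckpt_path, args.gpu_ids, and args.phase attributes read above, so the exact
# flag names may differ):
#
#     python get_cams.py --ckpt_path ckpts/best.pth.tar --gpu_ids 0 --phase test \
#         --target_layer module.encoders.3 --cam_dir cams/ --num_cams 5 --use_gbp True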