-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathvalidate.py
180 lines (149 loc) · 6.69 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Validation script."""
import argparse
import json
import os
import numpy as np
import torch
from tqdm import tqdm
from alphapose.models import builder
from alphapose.utils.config import update_config
from alphapose.utils.metrics import evaluate_mAP
from alphapose.utils.transforms import (flip, flip_heatmap,
get_func_heatmap_to_coord)
from alphapose.utils.pPose_nms import oks_pose_nms
# ---------------------------------------------------------------------------
# Command-line interface and module-level configuration.
# ``opt``, ``cfg`` and ``gpus`` are read as globals by the functions below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='AlphaPose Validate')
parser.add_argument('--cfg', type=str,
                    default='./configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml',
                    help='experiment configure file name')
parser.add_argument('--checkpoint', type=str,
                    default='./pretrained_models/fast_res50_256x192.pth',
                    help='checkpoint file name')
parser.add_argument('--gpus', type=str, default='0',
                    help='gpus')
parser.add_argument('--batch', type=int, default=1,
                    help='validation batch size')
parser.add_argument('--flip-test', dest='flip_test', action='store_true',
                    default=False, help='flip test')
parser.add_argument('--detector', dest='detector', default="yolo",
                    help='detector name')

opt = parser.parse_args()
cfg = update_config(opt.cfg)

# Full list of requested GPU ids (used for DataParallel); ``opt.gpus`` keeps
# only the first one, which also selects ``opt.device`` (a negative id means CPU).
gpus = list(map(int, opt.gpus.split(',')))
opt.gpus = gpus[:1]
opt.device = torch.device(f"cuda:{opt.gpus[0]}" if opt.gpus[0] >= 0 else "cpu")
def validate(m, heatmap_to_coord, batch_size=20):
    """Evaluate keypoint mAP on detector-produced person boxes.

    Builds the ``cfg.DATASET.TEST`` dataset (``opt`` is forwarded to the
    dataset builder) and runs the pose network ``m`` on every crop.
    Predictions are filtered with OKS-based pose NMS and written to
    ``./exp/json/raw/validate_rcnn_kpt.json``. They are then scored against
    the ``cfg.DATASET.VAL`` annotation file.

    Reads the module-level ``cfg`` and ``opt`` (``opt.flip_test`` enables
    flip testing).

    Args:
        m: pose network; ``m(inps)`` must return a 4-D tensor (asserted).
        heatmap_to_coord: callable mapping one heatmap, its crop box and an
            optional flipped heatmap to ``(pose_coords, pose_scores)``.
        batch_size: DataLoader batch size.

    Returns:
        The value returned by ``evaluate_mAP`` for the dumped result file.
    """
    out_file = './exp/json/raw/validate_rcnn_kpt.json'

    det_dataset = builder.build_dataset(cfg.DATASET.TEST, preset_cfg=cfg.DATA_PRESET, train=False, opt=opt)
    eval_joints = det_dataset.EVAL_JOINTS

    det_loader = torch.utils.data.DataLoader(
        det_dataset, batch_size=batch_size, shuffle=False, num_workers=20, drop_last=False)
    kpt_json = []

    m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    for inps, crop_bboxes, bboxes, img_ids, scores, _imghts, _imgwds in tqdm(det_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)

        # Must be defined on both paths: previously it was only assigned when
        # flip testing was on, so the default run crashed with NameError.
        pred_flip = None
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            output_flip = flip_heatmap(m(inps_flip), det_dataset.joint_pairs, shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]

        for i in range(output.shape[0]):
            bbox = crop_bboxes[i].tolist()
            # NOTE(review): assumes heatmap_to_coord accepts hms_flip=None and
            # then uses pred[i] alone -- confirm in alphapose.utils.transforms.
            hms_flip = pred_flip[i] if pred_flip is not None else None
            pose_coords, pose_scores = heatmap_to_coord(
                pred[i], bbox, hms_flip=hms_flip, hm_shape=hm_size, norm_type=norm_type)
            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i, 0].tolist()
            data['image_id'] = int(img_ids[i])
            data['area'] = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            # Detector confidence is used as-is (pose confidence is not mixed in).
            data['score'] = float(scores[i])
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    kpt_json = oks_pose_nms(kpt_json)

    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(out_file, ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.VAL.ROOT, cfg.DATASET.VAL.ANN))
    return res
def validate_gt(m, cfg, heatmap_to_coord, batch_size=20):
    """Evaluate keypoint mAP using the boxes provided by the validation set.

    Builds the ``cfg.DATASET.VAL`` dataset and runs the pose network ``m``
    on each box supplied by its loader. No pose NMS is applied. Results go
    to ``./exp/json/raw/validate_gt_kpt.json`` and are scored against the
    ``cfg.DATASET.VAL`` annotation file.

    Reads the module-level ``opt`` (``opt.flip_test`` enables flip testing).

    Args:
        m: pose network; ``m(inps)`` must return a 4-D tensor (asserted).
        cfg: experiment config (DATASET, DATA_PRESET and LOSS sections used).
        heatmap_to_coord: callable mapping one heatmap, its box and an
            optional flipped heatmap to ``(pose_coords, pose_scores)``.
        batch_size: DataLoader batch size.

    Returns:
        The value returned by ``evaluate_mAP`` for the dumped result file.
    """
    out_file = './exp/json/raw/validate_gt_kpt.json'

    gt_val_dataset = builder.build_dataset(cfg.DATASET.VAL, preset_cfg=cfg.DATA_PRESET, train=False)
    eval_joints = gt_val_dataset.EVAL_JOINTS

    gt_val_loader = torch.utils.data.DataLoader(
        gt_val_dataset, batch_size=batch_size, shuffle=False, num_workers=20, drop_last=False)
    kpt_json = []

    m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    for inps, _labels, _label_masks, img_ids, bboxes in tqdm(gt_val_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)

        # Must be defined on both paths: previously it was only assigned when
        # flip testing was on, so the default run crashed with NameError.
        pred_flip = None
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            output_flip = flip_heatmap(m(inps_flip), gt_val_dataset.joint_pairs, shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]

        for i in range(output.shape[0]):
            bbox = bboxes[i].tolist()
            # NOTE(review): assumes heatmap_to_coord accepts hms_flip=None and
            # then uses pred[i] alone -- confirm in alphapose.utils.transforms.
            hms_flip = pred_flip[i] if pred_flip is not None else None
            pose_coords, pose_scores = heatmap_to_coord(
                pred[i], bbox, hms_flip=hms_flip, hm_shape=hm_size, norm_type=norm_type)
            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bbox
            data['image_id'] = int(img_ids[i])
            # No detector score for these boxes: rank by pose confidence instead.
            data['score'] = float(np.mean(pose_scores) + np.max(pose_scores))
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(out_file, ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.VAL.ROOT, cfg.DATASET.VAL.ANN))
    return res
if __name__ == "__main__":
    # Script entry: load the pose model, then report mAP with the boxes
    # supplied by the validation set and with detector boxes.
    m = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

    print(f'Loading model from {opt.checkpoint}...')
    # Load tensors onto the CPU first. load_state_dict copies them into m's
    # own parameters anyway, and this avoids failing when the checkpoint was
    # saved from a GPU index that does not exist here.
    m.load_state_dict(torch.load(opt.checkpoint, map_location='cpu'))

    # DataParallel requires the module to live on device_ids[0]. A bare
    # .cuda() targets the current device (normally cuda:0), which broke
    # runs such as --gpus 1,2.
    m = torch.nn.DataParallel(m, device_ids=gpus).cuda(gpus[0])
    heatmap_to_coord = get_func_heatmap_to_coord(cfg)

    with torch.no_grad():
        gt_AP = validate_gt(m, cfg, heatmap_to_coord, opt.batch)
        print("gt_AP = {}".format(gt_AP))
        detbox_AP = validate(m, heatmap_to_coord, opt.batch)
        print("detbox = {}".format(detbox_AP))
    print('##### gt box: {} mAP | det box: {} mAP #####'.format(gt_AP, detbox_AP))