-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcoco2mediapipe.py
150 lines (131 loc) · 5.29 KB
/
coco2mediapipe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
import os
import argparse
parser = argparse.ArgumentParser(description='Test mediapipe data.')
parser.add_argument('-j', help='JSON file', dest='json', required=True)
parser.add_argument('-o', help='path to output folder', dest='out',required=True)
args = parser.parse_args()
json_file = args.json
output = args.out
class COCO2MEDIAPIPE:
def __init__(self):
self._check_file_and_dir(json_file, output)
self.labels = json.load(open(json_file, 'r', encoding='utf-8'))
self.coco_id_name_map = self._categories()
self.coco_name_list = list(self.coco_id_name_map.values())
print("total images", len(self.labels['images']))
print("total categories", len(self.labels['categories']))
annotation = self.labels['annotations']
for ann in annotation:
kpts = ann['keypoints']
print("keypoints labels on ",kpts)
print("total labels", len(self.labels['annotations']))
def _check_file_and_dir(self, file_path, dir_path):
if not os.path.exists(file_path):
raise ValueError("file not found")
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def _categories(self):
categories = {}
for cls in self.labels['categories']:
categories[cls['id']] = cls['name']
return categories
def _load_images_info(self):
images_info = {}
for image in self.labels['images']:
id = image['id']
file_name = image['file_name']
if file_name.find('\\') > -1:
file_name = file_name[file_name.index('\\')+1:]
w = image['width']
h = image['height']
images_info[id] = (file_name, w, h)
return images_info
def _bbox_2_mediapipe(self, bbox, img_w, img_h):
x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
centerx = bbox[0] + w / 2
centery = bbox[1] + h / 2
dw = 1 / img_w
dh = 1 / img_h
centerx *= dw
w *= dw
centery *= dh
h *= dh
return centerx, centery, w, h
def _keypoints_2_mediapipekeypoint(self, keypoints, img_w, img_h):
kpts = []
kptsxindex = [0]
kptsyindex = [1]
for i in range(1,33):
kptsxindex.append(0 + (i * 3))
kptsyindex.append(1 + (i * 3))
#print(keypoints)
for i in range (len(keypoints)):
if i not in kptsxindex and i not in kptsyindex:
kpts.append(keypoints[i])
elif i not in kptsxindex:
dh = 1 / img_h
kpty = keypoints[i] * dh
kpts.append(kpty)
else:
dw = 1 / img_w
kptx = keypoints[i] * dw
kpts.append(kptx)
return kpts
def _convert_anno(self, images_info):
anno_dict = dict()
for anno in self.labels['annotations']:
bbox = anno['bbox']
image_id = anno['image_id']
category_id = anno['category_id']
kpts_pos = anno['keypoints']
image_info = images_info.get(image_id)
image_name = image_info[0]
img_w = image_info[1]
img_h = image_info[2]
#print(img_h)
mediapipe_box = self._bbox_2_mediapipe(bbox, img_w, img_h)
mediapipekpts = self._keypoints_2_mediapipekeypoint(kpts_pos, img_w, img_h)
#print(mediapipekpts)
anno_info = (image_name, category_id, mediapipe_box, mediapipekpts)
anno_infos = anno_dict.get(image_id)
if not anno_infos:
anno_dict[image_id] = [anno_info]
else:
anno_infos.append(anno_info)
anno_dict[image_id] = anno_infos
return anno_dict
def save_classes(self):
sorted_classes = list(map(lambda x: x['name'], sorted(self.labels['categories'], key=lambda x: x['id'])))
print('coco names', sorted_classes)
with open('coco.names', 'w', encoding='utf-8') as f:
for cls in sorted_classes:
f.write(cls + '\n')
f.close()
def coco2mediapipe(self):
print("loading image info...")
images_info = self._load_images_info()
print("loading done, total images", len(images_info))
print("start converting...")
anno_dict = self._convert_anno(images_info)
print("converting done, total labels", len(anno_dict))
print("saving txt file...")
self._save_txt(anno_dict)
print("saving done")
def _save_txt(self, anno_dict):
for k, v in anno_dict.items():
file_name = v[0][0].split(".")[0] + ".txt"
with open(os.path.join(output, file_name), 'w', encoding='utf-8') as f:
#print(k, v)
for obj in v:
cat_name = self.coco_id_name_map.get(obj[1])
category_id = self.coco_name_list.index(cat_name)
box = ['{:.6f}'.format(x) for x in obj[2]]
box = ' '.join(box)
kpt = ['{:.6f}'.format(x) for x in obj[3]]
kpt = ' '.join(kpt)
line = str(category_id) + ' ' + box + ' ' + str(kpt)
f.write(line + '\n')
if __name__ == '__main__':
c2y = COCO2mediapipe()
c2y.coco2mediapipe()