-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert_vrd_splits.py
119 lines (98 loc) · 2.79 KB
/
convert_vrd_splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import lzma
import os
import zipfile
import cv2
import numpy as np
from tqdm import tqdm
from telenet.config import get as tn_config
# Root directory of the VRD dataset, taken from the project configuration.
VRD_PATH = tn_config('paths.vrd')


def _load_zip_json(archive, member, message):
    """Print ``message``, then parse the JSON file ``member`` inside ``archive``."""
    print(message)
    with archive.open(member) as member_file:
        return json.load(member_file)


# The annotation archive is only needed while the four JSON files are parsed.
with zipfile.ZipFile(os.path.join(VRD_PATH, 'vrd_json_dataset.zip'), 'r') as zf:
    vrd_objnames = _load_zip_json(
        zf, 'objects.json', 'Loading VRD object names...')
    vrd_relnames = _load_zip_json(
        zf, 'predicates.json', 'Loading VRD relation names...')
    vrd_train = _load_zip_json(
        zf, 'annotations_train.json', 'Loading VRD (train split)...')
    vrd_test = _load_zip_json(
        zf, 'annotations_test.json', 'Loading VRD (test split)...')

# The image archive stays open at module level: process_vrd() reads the raw
# image bytes from it while converting each split.
zfimg = zipfile.ZipFile(os.path.join(VRD_PATH, 'sg_dataset.zip'), 'r')
def conv_vrd_bb(vrdbb):
    """Convert a VRD box ``(ymin, ymax, xmin, xmax)`` to ``[x, y, width, height]``.

    Args:
        vrdbb: Four-element sequence ordered ``ymin, ymax, xmin, xmax``.

    Returns:
        A new list ``[xmin, ymin, xmax - xmin, ymax - ymin]``.
    """
    top, bottom, left, right = vrdbb
    width = right - left
    height = bottom - top
    return [left, top, width, height]
def collect_vrd_info(data):
    """Build the object list and relation list for one image's annotations.

    Args:
        data: Iterable of relation annotations. Each entry is read through the
            keys ``"subject"``, ``"object"`` and ``"predicate"``; subject and
            object each provide ``"category"`` and a four-element ``"bbox"``.

    Returns:
        ``None`` when no relation between two distinct objects remains.
        Otherwise a dict with:

        * ``'objs'``: list of ``{'v': category, 'bb': [x, y, w, h]}``, one per
          distinct (category, bbox) combination, in first-seen order.
        * ``'rels'``: one entry per ordered (subject, object) pair with the
          keys ``n`` (number of predicates), ``si``/``di`` (indices into
          ``objs``), ``sv``/``dv`` (their categories) and ``v`` (the
          predicates in sorted order).
    """
    objs = []
    obj_index = {}

    def get_obj_id(obj):
        # An object is identified by its category plus its raw box, so the
        # same box referenced by several relations maps to one ``objs`` entry.
        bbox = obj['bbox']
        key = (obj['category'], bbox[0], bbox[1], bbox[2], bbox[3])
        obj_id = obj_index.get(key)
        if obj_id is None:
            obj_id = len(objs)
            obj_index[key] = obj_id
            objs.append({
                'v': obj['category'],
                'bb': conv_vrd_bb(bbox),
            })
        return obj_id

    # (subject index, object index) -> set of predicates seen for that pair.
    relmap = {}
    for relobj in data:
        src = get_obj_id(relobj["subject"])
        dst = get_obj_id(relobj["object"])
        rel = relobj["predicate"]
        if src == dst:
            # Subject and object resolved to the same entry: report the
            # self-relation and drop it. The object itself stays in ``objs``.
            print('WARNING:', vrd_objnames[relobj["subject"]["category"]], vrd_relnames[rel], 'itself')
            continue
        relmap.setdefault((src, dst), set()).add(rel)

    if not relmap:
        return None

    rels = [
        {
            "n": len(relset),
            "si": src,
            "sv": objs[src]['v'],
            "di": dst,
            "dv": objs[dst]['v'],
            # Sorted so the emitted list does not depend on set iteration
            # order (predicates are used as indices into vrd_relnames above,
            # so they are assumed to be mutually comparable).
            "v": sorted(relset),
        }
        for (src, dst), relset in relmap.items()
    ]
    return {
        'objs': objs,
        'rels': rels,
    }
def process_vrd(vrd_split, splitname):
    """Convert one VRD split and write it to ``testdata/vrd-{splitname}.json.xz``.

    For every image in the split the annotations are converted with
    ``collect_vrd_info``, the image is read from the module-level ``zfimg``
    archive to obtain its width and height, and the result is appended to the
    output list. Images with no usable relations, images missing from the
    archive and images that cannot be decoded are skipped and counted.

    Args:
        vrd_split: Mapping from image file name to that image's list of
            relation annotations.
        splitname: Split label; used in the archive member path
            (``sg_dataset/sg_{splitname}_images/``) and the output file name.

    Returns:
        None. The converted split is written as xz-compressed JSON and the
        number of dropped images is printed.
    """
    bad = 0
    splitdata = []
    for imgname, annotations in tqdm(vrd_split.items()):
        data = collect_vrd_info(annotations)
        if not data:
            print(f'BAD: {imgname} has no relations, skipping')
            bad += 1
            continue

        try:
            imgraw = zfimg.read(f'sg_dataset/sg_{splitname}_images/{imgname}')
        except (KeyError, zipfile.BadZipFile):
            # KeyError: the member is not in the archive; BadZipFile: the
            # member is corrupted. Anything else (e.g. Ctrl-C) propagates.
            print(f'BAD: {imgname} cannot be opened, skipping')
            bad += 1
            continue

        try:
            img = cv2.imdecode(np.frombuffer(imgraw, np.uint8), cv2.IMREAD_ANYCOLOR)
        except cv2.error:
            # Raised for an empty buffer; treat like any undecodable image.
            img = None
        if img is None:
            # imdecode signals undecodable data by returning None.
            print(f'BAD: {imgname} cannot be decoded, skipping')
            bad += 1
            continue

        # Only the first two axes are needed; a single-channel image decodes
        # to a 2-D array, so a three-way unpack of ``shape`` would fail.
        h, w = img.shape[:2]
        data.update({
            # splitext keeps the whole name when it has no extension.
            "id": os.path.splitext(imgname)[0],
            "w": w,
            "h": h,
        })
        splitdata.append(data)

    with lzma.open(f'testdata/vrd-{splitname}.json.xz', 'wt', encoding='utf-8') as f:
        json.dump(splitdata, f)
    print(f'{bad} images dropped in {splitname} split')
# Write the name tables shared by both splits; VRD provides no attribute
# names, so 'attrs' is emitted as an empty list.
names_payload = {'objs': vrd_objnames, 'attrs': [], 'rels': vrd_relnames}
with open('testdata/vrd-names.json', 'wt', encoding='utf-8') as names_file:
    json.dump(names_payload, names_file)

# Convert the train split first, then the test split.
for split_label, split_annotations in (('train', vrd_train), ('test', vrd_test)):
    process_vrd(split_annotations, split_label)