forked from shekkizh/FCN.tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathread_KittiData.py
66 lines (52 loc) · 2.05 KB
/
read_KittiData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
__author__ = 'oflucas'
import numpy as np
import os
import random
from six.moves import cPickle as pickle
from tensorflow.python.platform import gfile
import glob
import TensorflowUtils as utils
# Presumably the dataset download URL; it is empty and is not referenced by
# any function visible in this module -- TODO confirm whether it is still needed.
DATA_URL = ''
# Default root of the training data; used as the default `data_dir` of read_dataset().
TRAIN_DATA_DIR = 'Data_zoo/kitti_road/data_road/training'
# Expected sub dirs: annotations_lane annotations_road calib gt_image_2 image_2
# (read_dataset() reads 'annotations_lane' and 'image_2').
def read_dataset(data_dir=TRAIN_DATA_DIR):
    """Collect the annotation/image pairs under ``data_dir`` and split them.

    Annotations are read from ``<data_dir>/annotations_lane`` and images from
    ``<data_dir>/image_2``.

    Args:
        data_dir: root of the training data; must exist.

    Returns:
        ``(training_records, validation_records)`` where each is a list of
        ``{'image': image_path, 'annotation': annotation_path}`` dicts.

    Raises:
        AssertionError: if ``data_dir`` does not exist.
    """
    assert os.path.exists(data_dir), "Cannot find dir = " + data_dir

    # Switch to 'annotations_road' to train on road instead of lane labels.
    anno_dir = os.path.join(data_dir, 'annotations_lane')
    im_dir = os.path.join(data_dir, 'image_2')
    records = retrive_records(anno_dir, im_dir)

    # split_records() returns the validation-sized slice FIRST and the
    # remainder second. Unpacking it as (training, validation) would leave
    # only ~20% of the data for training, so unpack in the matching order.
    validation_records, training_records = split_records(records)
    return training_records, validation_records
def retrive_records(anno_dir, im_dir):
    """Pair every annotation PNG in ``anno_dir`` with its image in ``im_dir``.

    Annotation files are expected to be named ``<cat>_<road_type>_<idx>.png``;
    the matching image is looked up as ``<cat>_<idx>.png`` inside ``im_dir``.
    Annotations whose name does not have exactly three ``_``-separated parts,
    or whose image does not exist, are reported on stdout and skipped.

    Args:
        anno_dir: directory containing the annotation ``*.png`` files.
        im_dir: directory containing the corresponding images.

    Returns:
        A list of ``{'image': image_path, 'annotation': annotation_path}``
        dicts, in the order ``glob`` yields the annotations.
    """
    res = []
    for anno in glob.glob(os.path.join(anno_dir, '*.png')):
        # basename() honours the platform separator, unlike split("/").
        fname = os.path.splitext(os.path.basename(anno))[0]
        parts = fname.split('_')
        if len(parts) != 3:
            # A stray PNG must not abort the whole scan with an unpack error.
            print('[ERROR] Unexpected Annotation File Name: ' + anno)
            continue
        cat, _, idx_str = parts

        im = os.path.join(im_dir, cat + '_' + idx_str + '.png')
        if not os.path.exists(im):
            # Single-argument print() behaves the same on Python 2 and 3.
            print('[ERROR] Annotation-Image Pair File Not Found: ' + im)
            continue
        res.append({'image': im, 'annotation': anno})
    return res
def test_data(im_dir):
assert os.path.exists(im_dir), "Cannot find dir = " + im_dir
res = []
if os.path.isfile(im_dir):
files = [im_dir]
if os.path.isdir(im_dir):
files = glob.glob(os.path.join(im_dir, '*.png'))
files.extend(glob.glob(os.path.join(im_dir, '*.jpg')))
for f in files:
fname = os.path.splitext(f.split("/")[-1])[0]
res.append({'image': f, 'name': fname})
return res
def split_records(records, validation_ratio=0.2):
    """Shuffle ``records`` in place and cut the list into two parts.

    Args:
        records: list of records; it is shuffled IN PLACE.
        validation_ratio: fraction of the records placed in the first part.

    Returns:
        ``(head, tail)`` where ``head`` holds the first
        ``int(validation_ratio * len(records))`` shuffled records and
        ``tail`` holds the rest. Note the validation-sized part comes first.
    """
    random.shuffle(records)
    cut = int(validation_ratio * len(records))
    head = records[:cut]
    tail = records[cut:]
    return head, tail