-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
87 lines (67 loc) · 2.09 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# coding: utf-8
import os
from PIL import Image
import numpy as np
from captcha_generator import CHARSET
# Lookup tables between captcha characters and their integer class indices,
# both derived from the position of each character in CHARSET.
encoded_map = {char: idx for idx, char in enumerate(CHARSET)}  # char -> index
decode_map = dict(enumerate(CHARSET))  # index -> char
class dataset(object):
    """In-memory captcha image dataset that serves shuffled mini-batches.

    Every ``.png`` file under ``folder`` (searched recursively) is loaded
    into memory as a numpy array scaled by 1/255. The file name without the
    ``.png`` suffix is the captcha text; each of its characters is mapped to
    an integer index through the module-level ``encoded_map``.

    Args:
        folder: directory walked recursively for ``.png`` files.
        batch_size: number of samples returned by each ``next_batch`` call.
    """

    def __init__(self, folder, batch_size):
        self.folder = folder
        self.batch_size = batch_size
        self._data = []
        self._labels = []
        self._size = 0
        self._pos = 0  # index of the next sample to hand out
        self._load()

    def _load(self):
        """Load every .png file under self.folder, then shuffle the samples.

        Raises:
            KeyError: if a file name contains a character that is not a key
                of ``encoded_map``.
        """
        suffix = '.png'
        data = []
        labels = []
        for root, _, files in os.walk(self.folder):
            for filename in files:
                if not filename.endswith(suffix):
                    continue
                # The context manager releases the file handle deterministically,
                # even if decoding the image fails.
                with Image.open(os.path.join(root, filename)) as img:
                    # 255.0 forces true division even under Python 2, where
                    # ``uint8_array / 255`` would floor almost every pixel to 0.
                    data.append(np.array(img) / 255.0)
                text = filename[:-len(suffix)]
                labels.append([encoded_map[c] for c in text])
        self._size = len(data)
        self._data = data
        self._labels = labels
        self._shuffle()

    def _shuffle(self):
        """Shuffle samples and labels together, keeping each pair aligned."""
        if not self._data:
            # zip(*[]) yields nothing, so the tuple unpack below would raise
            # ValueError on an empty dataset; there is nothing to shuffle anyway.
            return
        pairs = list(zip(self._data, self._labels))
        np.random.shuffle(pairs)
        self._data, self._labels = zip(*pairs)

    @property
    def data(self):
        return self._data

    @property
    def labels(self):
        return self._labels

    def next_batch(self):
        """Return the next mini-batch, wrapping around at the end of the data.

        When the end of the data is reached while filling a batch, the data
        is reshuffled and reading continues from the start. A batch larger
        than the dataset therefore spans several passes.

        Returns:
            batch: list
                len(batch) = self.batch_size
            labels: list
                encoded labels matching ``batch`` element by element.
            new_epoch: Boolean
                Indicates whether a loop is completed (the end of the data
                was reached while assembling this batch).

        Raises:
            ValueError: if the dataset is empty and a non-empty batch is
                requested; without this check the loop would never terminate.
        """
        if self._size == 0 and self.batch_size > 0:
            raise ValueError(
                'cannot draw a batch of size %d from an empty dataset'
                % self.batch_size)
        batch = []
        labels = []
        new_epoch = False
        while len(batch) < self.batch_size:
            need = self.batch_size - len(batch)
            end = min(self._size, self._pos + need)
            batch.extend(self._data[self._pos:end])
            labels.extend(self._labels[self._pos:end])
            if end == self._size:
                # Consumed the remaining data: restart from a fresh shuffle.
                self._pos = 0
                new_epoch = True
                self._shuffle()
            else:
                self._pos = end
        return batch, labels, new_epoch