-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathdecodeMinist.py
86 lines (73 loc) · 2.82 KB
/
decodeMinist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import struct
from PIL import Image
# Training-set image file
train_images_idx3_ubyte_file = (
    './datafile/train-images-idx3-ubyte/'
    'train-images.idx3-ubyte'
)
# Training-set label file
train_labels_idx1_ubyte_file = (
    './datafile/train-labels-idx1-ubyte/'
    'train-labels.idx1-ubyte'
)
# Test-set image file
test_images_idx3_ubyte_file = (
    './datafile/t10k-images-idx3-ubyte/'
    't10k-images.idx3-ubyte'
)
# Test-set label file
test_labels_idx1_ubyte_file = (
    './datafile/t10k-labels-idx1-ubyte/'
    't10k-labels.idx1-ubyte'
)
def decode_idx3_ubyte(idx3_ubyte_file):
    """
    Generic parser for idx3 (image) files.

    The file begins with a big-endian header of four signed 32-bit
    integers: magic number, image count, rows per image, columns per
    image. One unsigned byte per pixel follows the header.

    :param idx3_ubyte_file: path to the idx3 file
    :return: float64 array of shape (num_images, num_rows * num_cols);
             each row is one image flattened in file order
    :raises ValueError: (from numpy.frombuffer) if the file holds fewer
             pixel bytes than the header announces
    """
    # Read the whole file and close the handle promptly.
    with open(idx3_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number, image count, image height, image width.
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, 0)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
    image_size = num_rows * num_cols
    # Decode every pixel byte in one vectorized call. offset skips the
    # header; count reads exactly the announced pixels, so any trailing
    # bytes in the file are ignored.
    pixels = np.frombuffer(
        bin_data,
        dtype=np.uint8,
        count=num_images * image_size,
        offset=struct.calcsize(fmt_header),
    )
    # Return a writable float64 copy (values 0.0-255.0), one image per row.
    return pixels.reshape(num_images, image_size).astype(np.float64)
def decode_idx1_ubyte(idx1_ubyte_file):
    """
    Generic parser for idx1 (label) files.

    The file begins with a big-endian header of two signed 32-bit
    integers: magic number and label count. One unsigned byte per label
    follows the header.

    :param idx1_ubyte_file: path to the idx1 file
    :return: 1-D float64 array with one entry per label
    :raises ValueError: (from numpy.frombuffer) if the file holds fewer
             label bytes than the header announces
    """
    # Read the whole file and close the handle promptly.
    with open(idx1_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number and number of labels.
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, 0)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))
    # Decode all label bytes at once. offset skips the header; count
    # reads exactly the announced labels and ignores trailing bytes.
    labels = np.frombuffer(
        bin_data,
        dtype=np.uint8,
        count=num_images,
        offset=struct.calcsize(fmt_header),
    )
    # Return a writable float64 copy.
    return labels.astype(np.float64)
def save_images(url):
    """
    Decode the training-set images and save each one as a PNG file.

    Image number ``i`` is written to ``url + str(i) + '.png'``. ``url`` is
    therefore a filename prefix: to write into a directory, include the
    trailing path separator.

    :param url: output path prefix
    """
    images = decode_idx3_ubyte(train_images_idx3_ubyte_file)
    for i, im in enumerate(images):
        # Rebuild the full 28x28 grid (every pixel, not just the diagonal)
        # as uint8. Pillow turns a 2-D uint8 array into a single-channel
        # 'L' image, with array rows becoming image rows.
        pixels = im.reshape(28, 28).astype(np.uint8)
        Image.fromarray(pixels).save(url + str(i) + '.png')
    print("end")
def vectorized_result(j):
    """
    Build a one-hot row vector for index ``j``.

    :param j: position to set (any valid NumPy index into 10 columns)
    :return: float64 array of shape (1, 10), all zeros except 1.0 at column j
    """
    one_hot = np.zeros((1, 10))
    one_hot[0, j] = 1.0
    return one_hot