forked from chenyuntc/pytorch-book
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
194 lines (174 loc) · 7.38 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# coding:utf-8
import sys
import os
import json
import re
import numpy as np
def _parseRawData(author=None, constrain=None, src='./chinese-poetry/json/simplified', category="poet.tang"):
"""
code from https://github.com/justdark/pytorch-poetry-gen/blob/master/dataHandler.py
处理json文件,返回诗歌内容
@param: author: 作者名字
@param: constrain: 长度限制
@param: src: json 文件存放路径
@param: category: 类别,有poet.song 和 poet.tang
返回 data:list
['床前明月光,疑是地上霜,举头望明月,低头思故乡。',
'一去二三里,烟村四五家,亭台六七座,八九十支花。',
.........
]
"""
def sentenceParse(para):
# para 形如 "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,
# 生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)
# 好是去塵俗,煙花長一欄。"
result, number = re.subn(u"(.*)", "", para)
result, number = re.subn(u"{.*}", "", result)
result, number = re.subn(u"《.*》", "", result)
result, number = re.subn(u"《.*》", "", result)
result, number = re.subn(u"[\]\[]", "", result)
r = ""
for s in result:
if s not in set('0123456789-'):
r += s
r, number = re.subn(u"。。", u"。", r)
return r
def handleJson(file):
# print file
rst = []
data = json.loads(open(file).read())
for poetry in data:
pdata = ""
if (author is not None and poetry.get("author") != author):
continue
p = poetry.get("paragraphs")
flag = False
for s in p:
sp = re.split(u"[,!。]", s)
for tr in sp:
if constrain is not None and len(tr) != constrain and len(tr) != 0:
flag = True
break
if flag:
break
if flag:
continue
for sentence in poetry.get("paragraphs"):
pdata += sentence
pdata = sentenceParse(pdata)
if pdata != "":
rst.append(pdata)
return rst
data = []
for filename in os.listdir(src):
if filename.startswith(category):
data.extend(handleJson(src + filename))
return data
def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """
    Pad each sequence to the same length (adapted from Keras).

    By default every sequence is padded to the length of the longest one.
    If maxlen is provided, any sequence longer than maxlen is truncated to
    maxlen. Truncation happens off either the beginning (default) or the end
    of the sequence. Supports post-padding and pre-padding (default).

    Arguments:
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length; None means the longest sequence
            (0 for an empty `sequences`).
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: float, value to pad the sequences to the desired value.
    Returns:
        x: numpy array with dimensions (number_of_sequences, maxlen)
    Raises:
        ValueError: in case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        lengths.append(len(x))

    # Validate the mode strings up front, so a bad value is reported even
    # when every sequence is empty and the per-row loop never inspects it.
    if truncating not in ('pre', 'post'):
        raise ValueError('Truncating type "%s" not understood' % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError('Padding type "%s" not understood' % padding)

    num_samples = len(sequences)
    if maxlen is None:
        # default=0: an empty `sequences` yields a (0, 0) array instead of
        # failing inside a max-reduction over an empty list.
        maxlen = max(lengths, default=0)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:  # pylint: disable=g-explicit-length-test
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        # Nothing to copy for an empty sequence, or when rows have zero
        # width (s[-0:] would be the WHOLE sequence, not an empty slice).
        if not len(s) or maxlen == 0:  # pylint: disable=g-explicit-length-test
            continue
        if truncating == 'pre':
            trunc = s[-maxlen:]  # pylint: disable=invalid-unary-operand-type
        else:
            trunc = s[:maxlen]
        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (trunc.shape[1:], idx, sample_shape))
        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        else:
            x[idx, -len(trunc):] = trunc
    return x
def get_data(opt):
    """
    Load the preprocessed poem dataset, building and caching it on first use.

    @param opt: config object; reads opt.pickle_path, opt.author,
        opt.constrain, opt.data_path, opt.category and opt.maxlen.
    @return data: numpy array, one row per poem holding the indices of its
        characters (padded/truncated to opt.maxlen).
    @return word2ix: dict mapping each character to its index, e.g. u'月' -> 100
    @return ix2word: dict mapping each (int) index back to its character,
        e.g. 100 -> u'月'
    """
    if os.path.exists(opt.pickle_path):
        # The dicts are stored as 0-d object arrays, which requires pickle;
        # only load cache files you created yourself. The context manager
        # closes the underlying zip file once the arrays are read.
        with np.load(opt.pickle_path, allow_pickle=True) as cache:
            data = cache['data']
            word2ix = cache['word2ix'].item()
            ix2word = cache['ix2word'].item()
        return data, word2ix, ix2word

    # No cached binary yet: parse the raw JSON files.
    poems = _parseRawData(opt.author, opt.constrain, opt.data_path, opt.category)
    # Sorted so indices are reproducible across runs (set iteration order of
    # str depends on hash randomization).
    words = sorted({word for poem in poems for word in poem})
    word2ix = {word: ix for ix, word in enumerate(words)}
    word2ix['<EOP>'] = len(word2ix)  # end-of-poem marker
    word2ix['<START>'] = len(word2ix)  # start-of-poem marker
    word2ix['</s>'] = len(word2ix)  # padding ("blank") token
    ix2word = {ix: word for word, ix in word2ix.items()}

    # Wrap each poem with start/end markers and turn characters into indices,
    # e.g. [春, 江, 花, 月, 夜] -> [1, 2, 3, 4, 5].
    new_data = [[word2ix[word] for word in ['<START>'] + list(poem) + ['<EOP>']]
                for poem in poems]
    # Poems shorter than opt.maxlen are padded at the front with '</s>';
    # longer ones lose their tail.
    pad_data = pad_sequences(new_data,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=word2ix['</s>'])
    # Save through a file object: given a path string numpy appends ".npz"
    # when it is missing, and the exists() check above would then never find
    # the cache, rebuilding the data on every call.
    with open(opt.pickle_path, 'wb') as f:
        np.savez_compressed(f,
                            data=pad_data,
                            word2ix=word2ix,
                            ix2word=ix2word)
    return pad_data, word2ix, ix2word