Add cross-validation
elvisun committed Jan 14, 2018
1 parent 6df51be commit 091030e
Showing 2 changed files with 50 additions and 15 deletions.
50 changes: 38 additions & 12 deletions main.py
@@ -20,19 +20,27 @@
import sys
import io

DATA_FILE = './poetry_no_title.txt'
BIG_FILE = './poetry_no_title.txt'
DATA_FILE = './poetry_no_title_data.txt'
VALIDATION_FILE = './poetry_no_title_validation.txt'
TARGET_FILE = './result.txt'
WEIGHTS_FILE = './weights.h5'
TRAIN_TEST_SPLIT = 0.7


class generator:
def __init__(self):
self.weight_file = WEIGHTS_FILE
self.f = open(TARGET_FILE, 'w', encoding='utf-8')
self.text = io.open(DATA_FILE, encoding='utf-8').read().lower()
self.text = io.open(BIG_FILE, encoding='utf-8').read()
print('corpus length:', len(self.text))
self.chars = sorted(list(set(self.text)))
print('char space size:', len(self.chars))

self.data_text = io.open(DATA_FILE, encoding='utf-8').read()
self.validation_text = io.open(VALIDATION_FILE, encoding='utf-8').read()


def sample(self, preds, temperature=1.0):
# helper function to sample an index from a probability array
preds = np.asarray(preds).astype('float64')
@@ -92,7 +100,7 @@ def train(self):
self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

# cut the text in semi-redundant sequences of self.maxlen characters
MINI_BATCH_SIZE = 2048
MINI_BATCH_SIZE = 200
number_of_epoch = len(self.text)/MINI_BATCH_SIZE
self.maxlen = 5
step = 1
@@ -102,18 +110,36 @@ def train(self):
self.build_model()

print("training with epochs of: ", int(number_of_epoch))
self.model.fit_generator(self.generate_batch(),
steps_per_epoch=MINI_BATCH_SIZE,
epochs=int(number_of_epoch),
callbacks=[
LambdaCallback(on_epoch_end=self.save),
LambdaCallback(on_epoch_end=self.generate_sample_result)])
self.model.fit_generator(self.data_generator(),
steps_per_epoch=MINI_BATCH_SIZE,
epochs=int(number_of_epoch),
validation_data=self.validation_generator(),
# scale the training steps by (1 - split) / split so validation sees a proportional number of samples
validation_steps=MINI_BATCH_SIZE/TRAIN_TEST_SPLIT*(1-TRAIN_TEST_SPLIT),
callbacks=[
LambdaCallback(on_epoch_end=self.save),
LambdaCallback(on_epoch_end=self.generate_sample_result)])

def generate_batch(self):
def data_generator(self):
i = 0
while 1:
x = self.data_text[i: i + self.maxlen]
y = self.data_text[i + self.maxlen]

x_vec = np.zeros((1, self.maxlen, len(self.chars)), dtype=np.bool)
y_vec = np.zeros((1, len(self.chars)), dtype=bool)

y_vec[0, self.char_indices[y]] = 1
for t, char in enumerate(x):
x_vec[0, t, self.char_indices[char]] = 1
yield x_vec, y_vec
i += 1

def validation_generator(self):
i = 0
while 1:
x = self.text[i: i + self.maxlen]
y = self.text[i + self.maxlen]
x = self.validation_text[i: i + self.maxlen]
y = self.validation_text[i + self.maxlen]

x_vec = np.zeros((1, self.maxlen, len(self.chars)), dtype=np.bool)
y_vec = np.zeros((1, len(self.chars)), dtype=bool)
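A note on the new generators, with a rough sketch below (window_generator and the wrap-around check are illustrative, not part of this commit). data_generator and validation_generator advance i by one forever and never reset it, so once i + self.maxlen reaches the end of the text the single-character index for y raises an IndexError; wrapping i back to zero keeps the generator infinite, which is what fit_generator expects. The validation_steps expression scales the 200 training steps per epoch by the held-out fraction: 200 / 0.7 * (1 - 0.7) ≈ 85.7 validation steps.

import numpy as np

MINI_BATCH_SIZE = 200
TRAIN_TEST_SPLIT = 0.7

def window_generator(text, char_indices, maxlen=5):
    """Yield one one-hot (x, y) pair per step, wrapping at the end of the text."""
    n_chars = len(char_indices)
    i = 0
    while True:
        if i + maxlen >= len(text):
            i = 0  # wrap instead of indexing past the last character
        x, y = text[i:i + maxlen], text[i + maxlen]
        # plain bool avoids the np.bool alias, which newer NumPy releases removed
        x_vec = np.zeros((1, maxlen, n_chars), dtype=bool)
        y_vec = np.zeros((1, n_chars), dtype=bool)
        for t, char in enumerate(x):
            x_vec[0, t, char_indices[char]] = 1
        y_vec[0, char_indices[y]] = 1
        yield x_vec, y_vec
        i += 1

# about 200 / 0.7 * 0.3 ≈ 85.7, truncated to an integer step count
validation_steps = int(MINI_BATCH_SIZE / TRAIN_TEST_SPLIT * (1 - TRAIN_TEST_SPLIT))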
15 changes: 12 additions & 3 deletions preprocess_data.py
@@ -1,10 +1,19 @@
# remove title from data
def main():
split_ratio = 0.7
f = open('./poetry.txt', encoding='utf-8')
dataFile = open('./poetry_no_title_data.txt', 'w', encoding='utf-8')
validationFile = open('./poetry_no_title_validation.txt', 'w', encoding='utf-8')

lineNumber = len(f.readlines())

f = open('./poetry.txt', encoding='utf-8')
targetFile = open('./poetry_no_title.txt', 'w', encoding='utf-8')
for i, line in enumerate(f.readlines()):
targetFile.write(line.split(':')[1])
print(i)
newLine = line.split(':')[1]
if (i < lineNumber * split_ratio):
dataFile.write(newLine)
else:
validationFile.write(newLine)


if __name__ == '__main__':
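The split above reads poetry.txt once just to count lines, then reopens it to iterate; the first 70% of poems (one per line, with the title before the ':' dropped) go to the data file and the rest to the validation file. A compact equivalent that reads the file a single time is sketched below (split_poetry is an illustrative name, not part of the commit):

def split_poetry(source='./poetry.txt',
                 data_path='./poetry_no_title_data.txt',
                 validation_path='./poetry_no_title_validation.txt',
                 split_ratio=0.7):
    """Drop the title before ':' on each line and split the lines ~70/30."""
    with open(source, encoding='utf-8') as f, \
         open(data_path, 'w', encoding='utf-8') as data_file, \
         open(validation_path, 'w', encoding='utf-8') as validation_file:
        lines = f.readlines()  # read once instead of reopening to count
        cutoff = len(lines) * split_ratio
        for i, line in enumerate(lines):
            body = line.split(':')[1]  # keep only the poem body after the title
            if i < cutoff:
                data_file.write(body)
            else:
                validation_file.write(body)

if __name__ == '__main__':
    split_poetry()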
