Skip to content

Commit

Permalink
Code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
adbrebs committed Apr 13, 2016
1 parent 2eb7432 commit 96f1919
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 226 deletions.
35 changes: 18 additions & 17 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,30 @@ def load_data(filename='hand_training.hdf5'):
training_data_file = os.path.join(data_folder, filename)
train_data = h5py.File(training_data_file, 'r')

coord_seq = train_data['pt_seq'][:]
coord_idx = train_data['pt_idx'][:]
pt_seq = train_data['pt_seq'][:]
pt_idx = train_data['pt_idx'][:]
strings_seq = train_data['str_seq'][:]
strings_idx = train_data['str_idx'][:]

train_data.close()
return coord_seq, coord_idx, strings_seq, strings_idx
return pt_seq, pt_idx, strings_seq, strings_idx


def create_generator(shuffle, batch_size, seq_coord, coord_idx,
def create_generator(shuffle, batch_size, seq_pt, pt_idx,
seq_strings, strings_idx, chunk=None):
n_seq = coord_idx.shape[0]
idx = np.arange(n_seq)
np.random.shuffle(idx)
n_seq = pt_idx.shape[0]

coord_idx = coord_idx[idx]
strings_idx = strings_idx[idx]
if shuffle:
idx = np.arange(n_seq)
np.random.shuffle(idx)
pt_idx = pt_idx[idx]
strings_idx = strings_idx[idx]

def generator():
for i in range(0, n_seq-batch_size, batch_size):
pt, pt_mask, str, str_mask = \
extract_sequence(slice(i, i + batch_size),
seq_coord, coord_idx, seq_strings, strings_idx)
seq_pt, pt_idx, seq_strings, strings_idx)

pt_input = pt[:-1]
pt_tg = pt[1:]
Expand All @@ -55,18 +56,18 @@ def generator():
return generator


def extract_sequence(slice, coord, coord_idx, strings, str_idx, M=None):
def extract_sequence(s, pt, pt_idx, strings, str_idx, M=None):
"""
the slice represents the minibatch
- coord: shape (number points, 3)
- coord_idx: shape (number of sequences, 2): the starting and end points of
the slice s selects the sequences that form the minibatch
- pt: shape (number of points, 3)
- pt_idx: shape (number of sequences, 2): the start and end indices in pt of
each sequence
"""
if not M:
M = 1500

pt_idxs = coord_idx[slice]
str_idxs = str_idx[slice]
pt_idxs = pt_idx[s]
str_idxs = str_idx[s]

longuest_pt_seq = max([b - a for a, b in pt_idxs])
longuest_pt_seq = min(M, longuest_pt_seq)
Expand All @@ -78,7 +79,7 @@ def extract_sequence(slice, coord, coord_idx, strings, str_idx, M=None):
str_mask_batch = np.zeros((longuest_str_seq, len(str_idxs)), dtype=floatX)

for i, (pt_seq, str_seq) in enumerate(zip(pt_idxs, str_idxs)):
pts = coord[pt_seq[0]:pt_seq[1]]
pts = pt[pt_seq[0]:pt_seq[1]]
limit2 = min(pts.shape[0], longuest_pt_seq)
pt_batch[:limit2, i] = pts[:limit2]
pt_mask_batch[:limit2, i] = 1
Expand Down
28 changes: 10 additions & 18 deletions data_raw2hdf5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#!/usr/bin/env python

"""
Generate the .hdf5 training and validation datasets from the raw files.
Inspired by Alex Graves' pipeline.
"""

import time
import os
from os.path import join
Expand All @@ -24,6 +29,7 @@

char_dic, _ = cPickle.load(open('char_dict.pkl', 'r'))


def get_target_string(stroke_filename):

ascii_filename = re.sub('lineStrokes', 'ascii', stroke_filename)
Expand All @@ -50,9 +56,9 @@ def read_file(file_path):
pts = []
pre_pt = np.array([])
for trace in parse(file_path).getElementsByTagName('Stroke'):
for coords in trace.getElementsByTagName('Point'):
pt = np.array([coords.getAttribute('x').strip(),
coords.getAttribute('y').strip(), 0],
for pts in trace.getElementsByTagName('Point'):
pt = np.array([pts.getAttribute('x').strip(),
pts.getAttribute('y').strip(), 0],
dtype='float32')
if not len(pre_pt):
pre_pt = pt
Expand Down Expand Up @@ -99,26 +105,12 @@ def read_file(file_path):
print time.clock() - start


# M_x = pt_seq[:, 0].mean()
# M_y = pt_seq[:, 1].mean()
# s_x = pt_seq[:, 0].std()
# s_y = pt_seq[:, 1].std()

# Normalize
pt_seq[:, 0] = (pt_seq[:, 0] - M_x) / s_x
pt_seq[:, 1] = (pt_seq[:, 1] - M_y) / s_y


#
from utilities import plot_coord
from data import extract_sequence
pt_batch, pt_mask_batch, str_batch = \
extract_sequence(slice(0, 4), pt_seq, pt_idx, str_seq, str_idx)
plot_coord(pt_batch, pt_mask_batch, use_mask=True, show=True)




# Write the dataset
f = h5py.File(h5_filename, 'w')

ds_points = f.create_dataset(
Expand Down
53 changes: 0 additions & 53 deletions essai.py

This file was deleted.

18 changes: 9 additions & 9 deletions extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from raccoon.extensions import Saver, ValMonitor

from data import char2int
from utilities import plot_coord, plot_generated_sequences
from utilities import plot_seq_pt, plot_generated_sequences

floatX = theano.config.floatX

Expand All @@ -33,9 +33,9 @@ def execute_virtual(self, batch_id):
sample = self.fun_pred(np.zeros((self.n_samples, 3), floatX),
np.zeros((self.n_samples, self.n_hidden), floatX))

plot_coord(sample,
folder_path=self.folder_path,
file_name='{}_'.format(batch_id) + self.file_name)
plot_seq_pt(sample,
folder_path=self.folder_path,
file_name='{}_'.format(batch_id) + self.file_name)

return ['executed']

Expand All @@ -62,7 +62,7 @@ def __init__(self, name_extension, freq, folder_path, file_name,
self.bias_value = bias_value

# Initial values
self.coord_ini_mat = np.zeros((n_samples, 3), floatX)
self.pt_ini_mat = np.zeros((n_samples, 3), floatX)
self.h_ini_mat = np.zeros((n_samples, model.n_hidden), floatX)
self.k_ini_mat = np.zeros((n_samples, model.n_mixt_attention), floatX)
self.w_ini_mat = np.zeros((n_samples, model.n_chars), floatX)
Expand All @@ -71,18 +71,18 @@ def execute_virtual(self, batch_id):

cond, cond_mask = char2int(self.sample_strings, self.dict_char2int)

coord_gen, a_gen, k_gen, p_gen, w_gen, mask_gen = self.f_sampling(
self.coord_ini_mat, cond, cond_mask,
pt_gen, a_gen, k_gen, p_gen, w_gen, mask_gen = self.f_sampling(
self.pt_ini_mat, cond, cond_mask,
self.h_ini_mat, self.k_ini_mat, self.w_ini_mat, self.bias_value)

# plot_coord(coord_gen,
# plot_seq_pt(pt_gen,
# folder_path=self.folder_path,
# file_name='{}_'.format(batch_id) + self.file_name)
p_gen = np.swapaxes(p_gen, 1, 2)
mats = [(a_gen, 'alpha'), (k_gen, 'kapa'), (p_gen, 'phi'),
(w_gen, 'omega')]
plot_generated_sequences(
coord_gen, mats,
pt_gen, mats,
mask_gen, folder_path=self.folder_path,
file_name='{}_'.format(batch_id) + self.file_name)

Expand Down
Loading

0 comments on commit 96f1919

Please sign in to comment.