-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathmodel.py
71 lines (61 loc) · 2.24 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
import theano
import theano.tensor as T
from blocks.bricks.cost import CategoricalCrossEntropy, SquaredError
from blocks.bricks import MLP, Linear, Tanh, Softmax
from blocks.bricks.recurrent import LSTM
from blocks.initialization import IsotropicGaussian, Constant
# Module-level logger registered under the 'main.model' name.
logger = logging.getLogger('main.model')
# Alias for the float dtype configured in Theano (theano.config.floatX).
floatX = theano.config.floatX
class MLPModel():
    """Feed-forward classifier built from a single Blocks ``MLP`` brick.

    The network has one tanh hidden layer followed by a softmax output
    layer; layer sizes are taken from ``self.dims``.

    Attributes set by ``__init__``:
        non_lins: activation bricks, one per layer transition.
        dims: layer sizes ``[input, hidden, output]``.
        default_lr: default learning rate associated with this model.
        name: name given to the underlying ``MLP`` brick.

    Attribute set by ``apply``:
        outputs: dict with the symbolic variables ``'probs'`` and ``'cost'``.
    """

    def __init__(self, name='MLP'):
        """Store the architecture and hyper-parameters; builds no graph."""
        # Tanh for the hidden layer, softmax to produce class probabilities.
        self.non_lins = [Tanh(), Softmax()]
        # [input size, hidden size, number of classes].
        # NOTE(review): 784 presumably means flattened 28x28 inputs - confirm.
        self.dims = [784, 100, 10]
        self.default_lr = 0.1
        self.name = name

    def apply(self, input_, target):
        """Build the symbolic forward pass and the cross-entropy cost.

        Args:
            input_: symbolic input; it is flattened to two dimensions
                before being fed to the MLP.
            target: symbolic targets; flattened to one dimension before
                being passed to the cross-entropy brick.

        Returns:
            None. Results are stored in ``self.outputs`` under the keys
            ``'probs'`` and ``'cost'``.
        """
        mlp = MLP(self.non_lins, self.dims,
                  weights_init=IsotropicGaussian(0.01),
                  biases_init=Constant(0),
                  name=self.name)
        mlp.initialize()

        # A leftover interactive debugger breakpoint used to sit here; it
        # halted every call (and required an extra package), so it is gone.
        probs = mlp.apply(T.flatten(input_, outdim=2))
        probs.name = 'probs'

        cost = CategoricalCrossEntropy().apply(target.flatten(), probs)
        cost.name = "CE"

        self.outputs = {}
        self.outputs['probs'] = probs
        self.outputs['cost'] = cost
class LSTMModel():
    """Recurrent model: linear input projection -> LSTM -> linear read-out.

    Layer sizes come from ``self.dims`` as ``[input, hidden, output]``.
    After ``apply`` has run, ``self.outputs`` holds the symbolic variables
    ``'y_hat'``, ``'cost'``, ``'pre_rnn'`` and ``'h'``.
    """

    def __init__(self, name='LSTM'):
        """Record the model name, default learning rate and layer sizes."""
        self.name = name
        self.default_lr = 0.01
        # [input size, LSTM hidden size, output size]
        self.dims = [2, 7, 2]

    def apply(self, input_, target):
        """Build the symbolic graph and the squared-error cost.

        Args:
            input_: symbolic input passed to the first linear brick.
            target: symbolic target compared against the prediction.

        Returns:
            None. The resulting variables are stored in ``self.outputs``.
        """
        dims = self.dims

        # The projection is four times the hidden size wide, which is the
        # input width the LSTM brick below is fed with.
        input_proj = Linear(name='x_to_h',
                            input_dim=dims[0],
                            output_dim=dims[1] * 4)
        pre_rnn = input_proj.apply(input_)
        pre_rnn.name = 'pre_rnn'

        lstm = LSTM(activation=Tanh(),
                    dim=dims[1], name=self.name)
        # Only the first of the two values returned by the LSTM is used.
        states, _ = lstm.apply(pre_rnn)
        states.name = 'h'

        readout = Linear(name='h_to_y',
                         input_dim=dims[1],
                         output_dim=dims[2])
        prediction = readout.apply(states)
        prediction.name = 'y_hat'

        mse = SquaredError().apply(target, prediction)
        mse.name = 'MSE'

        self.outputs = {
            'y_hat': prediction,
            'cost': mse,
            'pre_rnn': pre_rnn,
            'h': states,
        }

        # Initialization: every brick gets the same weight/bias scheme and is
        # initialised once the graph has been built.
        for brick in (lstm, input_proj, readout):
            brick.weights_init = IsotropicGaussian(0.01)
            brick.biases_init = Constant(0)
            brick.initialize()