-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfull_model.py
265 lines (212 loc) · 9.91 KB
/
full_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# Learning the transition probabilities with the PES rule
# Uses a modified version of the Agent that takes in one 'value' at a time
import nengo
from nengo import spa
import numpy as np
from environment import Environment
from modelbasednode import Agent, AgentSplit
#import nengolib
from nengolib.signal import z
import scipy
# --- Model hyperparameters -------------------------------------------------
# Dimensionality of the semantic pointers representing states and actions.
DIM = 16#5#64
# When True, run ensembles/connections in nengo.Direct() mode (no spiking).
direct = False#True
# When True, learn the state-transition mapping online with the PES rule.
learning = False#True
# When True, the learned connection starts from correct_mapping;
# otherwise from the uninformed initial_mapping.
initialized = True
# If the output of the probability population is forced to be a probability that adds to 1
forced_prob = False#True
learning_rate=9e-5#5e-6
# Time between state transitions
time_interval = 0.1#0.5
nengo_seed=13
states = ['S0', 'S1', 'S2']
actions = ['L', 'R']
#n_sa_neurons = DIM*2*15 # number of neurons in the state+action population
n_sa_neurons = DIM*2*75 # number of neurons in the state+action population
n_prod_neurons = DIM*15 # number of neurons in the product network
# Set all vectors to be orthogonal for now (easy debugging)
vocab = spa.Vocabulary(dimensions=DIM)
#vocab = spa.Vocabulary(dimensions=DIM, randomize=False)
# TODO: these vectors might need to be chosen in a smarter way
for sp in states+actions:
    vocab.parse(sp)
class AreaIntercepts(nengo.dists.Distribution):
    """Intercept distribution that is uniform over *response area* rather
    than intercept value.

    In high-dimensional ensembles, intercepts sampled uniformly in value
    make most neurons respond to nearly all inputs; mapping samples of
    ``base`` through the inverse regularized incomplete beta function
    spreads the active proportion of the space uniformly instead.
    """
    dimensions = nengo.params.NumberParam('dimensions')
    base = nengo.dists.DistributionParam('base')

    def __init__(self, dimensions, base=nengo.dists.Uniform(-1, 1)):
        super(AreaIntercepts, self).__init__()
        self.dimensions = dimensions
        self.base = base

    def __repr__(self):
        # BUG FIX: this was named `__repr` (missing trailing underscores),
        # so Python's repr() machinery never called it.
        return "AreaIntercepts(dimensions=%r, base=%r)" % (self.dimensions, self.base)

    def transform(self, x):
        """Map one sample x in [-1, 1] to the corresponding intercept."""
        sign = 1
        if x > 0:
            x = -x
            sign = -1
        return sign * np.sqrt(1 - scipy.special.betaincinv(
            (self.dimensions + 1) / 2.0, 0.5, x + 1))

    def sample(self, n, d=None, rng=np.random):
        """Draw n samples from ``base`` and transform each one."""
        s = self.base.sample(n=n, d=d, rng=rng)
        for i in range(len(s)):
            s[i] = self.transform(s[i])
        return s
# The ideal function that should be learned
def correct_mapping(x):
    """Ideal transition function that the PES rule should converge to.

    ``x`` is the concatenation of a state semantic pointer (first DIM
    elements) and an action semantic pointer (last DIM elements).  From S0,
    'L' leads to S1 with p=.7 and S2 with p=.3; 'R' is the mirror image.
    From S1 or S2, every action returns to S0.
    """
    state_idx = find_closest_vector(x[:DIM], index_to_state_vector)
    if state_idx != 0:
        # Always return to state 0 from S1/S2
        return index_to_state_vector[0]
    action_idx = find_closest_vector(x[DIM:], index_to_action_vector)
    if action_idx == 0:    # Left
        return index_to_state_vector[1]*.7 + index_to_state_vector[2]*.3
    elif action_idx == 1:  # Right
        return index_to_state_vector[1]*.3 + index_to_state_vector[2]*.7
def initial_mapping(x):
    """Uninformed starting transition function (flat prior).

    Same interface as correct_mapping, but from S0 both actions are
    believed to lead to S1/S2 with equal probability; S1 and S2 still
    always return to S0.
    """
    state_idx = find_closest_vector(x[:DIM], index_to_state_vector)
    if state_idx != 0:
        # Always return to state 0 from S1/S2
        return index_to_state_vector[0]
    action_idx = find_closest_vector(x[DIM:], index_to_action_vector)
    # Identical 50/50 guess regardless of whether the action is L or R.
    if action_idx == 0:
        return index_to_state_vector[1]*.5 + index_to_state_vector[2]*.5
    elif action_idx == 1:
        return index_to_state_vector[1]*.5 + index_to_state_vector[2]*.5
def selected_error(t, x):
    """Nengo Node function gating the PES error signal.

    ``x`` packs three DIM-sized slots: the raw error, the action actually
    taken, and the action used for the value calculation.  The error is
    passed through only when the two actions decode to the same index;
    otherwise a zero vector is emitted (no learning on that step).
    """
    error_signal = x[:DIM]
    taken_idx = find_closest_vector(x[DIM:DIM*2], index_to_action_vector)
    calc_idx = find_closest_vector(x[DIM*2:], index_to_action_vector)
    if taken_idx == calc_idx:
        return error_signal
    return np.zeros(DIM)
#FIXME: this is currently hardcoded for only 5 dimensions
def make_probability(t, x):
s0 = min(max(0, x[0]),1)
s1 = min(max(0, x[1]),1)
s2 = min(max(0, x[2]),1)
#total = np.sum(x[0], x[1], x[2])
total = s0 + s1 + s2
if total > 0:
return (s0/total, s1/total, s2/total, x[3], x[4])
else:
return x
# takes a state index and returns the corresponding vector for the semantic pointer
index_to_state_vector = np.zeros((len(states), DIM))
# takes an action index and returns the corresponding vector for the semantic pointer
index_to_action_vector = np.zeros((len(actions), DIM))
# Fill in mapping data structures based on the vocab given
for i, vk in enumerate(vocab.keys):
if vk in actions:
index_to_action_vector[actions.index(vk)] = vocab.vectors[i]
if vk in states:
index_to_state_vector[states.index(vk)] = vocab.vectors[i]
def find_closest_vector(vec, index_to_vector):
    """Return the row index of ``index_to_vector`` whose dot product with
    ``vec`` is largest and strictly positive; defaults to 0 when no
    candidate scores above zero (matching the original tie/negative
    handling)."""
    scores = [np.dot(vec, candidate) for candidate in index_to_vector]
    if not scores:
        return 0
    winner = max(range(len(scores)), key=scores.__getitem__)
    return winner if scores[winner] > 0 else 0
# Which intercept distribution to use for the state+action ensemble:
# 0 = Uniform(-1, 1), 1 = AreaIntercepts, 2 = Uniform(.3, 1)
intercept_dist = 0

model = nengo.Network('RL P-learning', seed=nengo_seed)
with model:
    cfg = nengo.Config(nengo.Ensemble, nengo.Connection)
    if direct:
        cfg[nengo.Ensemble].neuron_type = nengo.Direct()
        # BUG FIX: "None" was split across two source lines in the original
        # (a paste/scrape artifact), which was a syntax error.
        cfg[nengo.Connection].synapse = None

    if intercept_dist == 0:
        intercepts = nengo.dists.Uniform(-1, 1)
    elif intercept_dist == 1:
        intercepts = AreaIntercepts(dimensions=DIM*2)
    elif intercept_dist == 2:
        intercepts = nengo.dists.Uniform(.3, 1)

    # Model of the external environment
    # Input: action semantic pointer
    # Output: current state semantic pointer
    #model.env = nengo.Node(Environment(vocab=vocab, time_interval=time_interval), size_in=DIM, size_out=DIM)
    #model.env = nengo.Node(Agent(vocab=vocab, time_interval=time_interval),
    #                       size_in=2, size_out=DIM*3)
    model.env = nengo.Node(AgentSplit(vocab=vocab, time_interval=time_interval),
                           size_in=1, size_out=DIM*4)

    with cfg:
        model.state = spa.State(DIM, vocab=vocab)
        model.action = spa.State(DIM, vocab=vocab)
        model.probability = spa.State(DIM, vocab=vocab)
        # The action that is currently being used along with the state to calculate value
        # If this matches with the actual action being taken, learning will happen (on the next step after a delay)
        model.calculating_action = spa.State(DIM, vocab=vocab)

    nengo.Connection(model.env[DIM*3:DIM*4], model.calculating_action.input)

    if learning:
        # State and selected action in one ensemble
        model.state_and_action = nengo.Ensemble(n_neurons=n_sa_neurons, dimensions=DIM*2, intercepts=intercepts)
        if initialized:
            function = correct_mapping
        else:
            function = initial_mapping
        # PES-learned connection; pre_synapse delays the presynaptic
        # activity by two time_interval steps so the error lines up with
        # the transition it corresponds to.
        conn = nengo.Connection(model.state_and_action, model.probability.input,
                                function=function,
                                learning_rule_type=nengo.PES(pre_synapse=z**(-int(time_interval*2*1000)),
                                                             learning_rate=learning_rate),
                                )
    else:
        with cfg:
            # State and selected action in one ensemble
            model.state_and_action = nengo.Ensemble(n_neurons=n_sa_neurons, dimensions=DIM*2, intercepts=intercepts)
            nengo.Connection(model.state_and_action, model.probability.input, function=correct_mapping)

    with cfg:
        nengo.Connection(model.state.output, model.state_and_action[:DIM])
        nengo.Connection(model.env[DIM*3:DIM*4], model.state_and_action[DIM:])

    # Semantic pointer for the Q values of each state
    # In the form of q0*S0 + q1*S1 + q2*S2
    model.q = spa.State(DIM, vocab=vocab)
    # Scalar value from the dot product of P and Q
    model.value = nengo.Ensemble(100, 1)
    #TODO: figure out what the result of P.Q is used for
    model.prod = nengo.networks.Product(n_neurons=n_prod_neurons, dimensions=DIM)

    if forced_prob:
        # Normalize the probability vector before the product network.
        normalized_prob = nengo.Node(make_probability, size_in=DIM, size_out=DIM)
        nengo.Connection(model.probability.output, normalized_prob, synapse=None)
        nengo.Connection(normalized_prob, model.prod.A, synapse=None)
    else:
        nengo.Connection(model.probability.output, model.prod.A)

    #nengo.Connection(model.q.output, model.prod.B)
    nengo.Connection(model.env[DIM*2:DIM*3], model.prod.B)
    # value = sum over dimensions of elementwise P*Q (i.e. the dot product)
    nengo.Connection(model.prod.output, model.value,
                     transform=np.ones((1, DIM)))

    #TODO: doublecheck that this is the correct way to connect things
    nengo.Connection(model.env[DIM:DIM*2], model.state.input)

    #TODO: need to set up error signal and handle timing
    model.error = spa.State(DIM, vocab=vocab)
    ##nengo.Connection(model.error.output, conn.learning_rule)
    # Gate the error by whether the taken action matches the calculating action.
    model.error_node = nengo.Node(selected_error, size_in=DIM*3, size_out=DIM)
    nengo.Connection(model.error.output, model.error_node[:DIM])
    nengo.Connection(model.action.output, model.error_node[DIM:DIM*2])
    nengo.Connection(model.calculating_action.output, model.error_node[DIM*2:DIM*3])
    if learning:
        nengo.Connection(model.error_node, conn.learning_rule)

    #TODO: figure out which way the sign goes, one should be negative, and the other positive
    #TODO: figure out how to delay by one "time-step" correctly
    nengo.Connection(model.state.output, model.error.input, transform=-1)
    nengo.Connection(model.probability.output, model.error.input, transform=1,
                     synapse=z**(-int(time_interval*2*1000)))
    #synapse=nengolib.synapses.PureDelay(500)) #500ms delay

    # Testing the delay synapse to make sure it works as expected
    model.state_delay_test = spa.State(DIM, vocab=vocab)
    nengo.Connection(model.state.output, model.state_delay_test.input,
                     synapse=z**(-int(time_interval*2*1000)))

    #nengo.Connection(model.value, model.env)
    nengo.Connection(model.value, model.env, synapse=0.025)
    nengo.Connection(model.env[:DIM], model.action.input)  # Purely for plotting
    nengo.Connection(model.env[DIM*2:DIM*3], model.q.input)  # Purely for plotting