-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathestimate_ising.m
138 lines (124 loc) · 4.39 KB
/
estimate_ising.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
function [h0 J train_logical] = estimate_ising(iters)
% ESTIMATE_ISING Fit a pairwise Ising model to recorded spike trains.
%   [h0, J, train_logical] = ESTIMATE_ISING(iters) loads the cell array
%   NEURON_TRAINS from neuron_trains.mat and concatenates it with CELL2MAT
%   into an N x T matrix (one row per neuron, one column per time bin).
%   A random 20% of the columns is held out as a test set. It then runs
%   ITERS steps of moment-matching gradient ascent, re-estimating the model
%   statistics each step with MH_SAMPLE_ISING.
%
%   Input:
%     iters         - number of gradient-ascent iterations (non-negative
%                     integer).
%
%   Outputs:
%     h0            - 1 x N field vector from the iteration whose coupling
%                     update had the smallest largest magnitude (the
%                     initial values when iters == 0).
%     J             - N x N coupling matrix from that same iteration.
%     train_logical - T x 1 logical mask of the columns used for training.
%
%   Side effects: prints per-iteration progress and timing, and opens three
%   figures:
%     - the per-iteration deviation of the model from the training data;
%     - the model means compared against the held-out test data;
%     - the model pairwise covariances compared against the test data.
%
%   NOTE(review): spins appear to be encoded as -1/+1. The initial sampler
%   states are drawn from {-1,+1}, and h0 is initialised with atanh of the
%   mean. Confirm this against the contents of neuron_trains.mat.

    % Load into a struct rather than poofing a variable into the workspace.
    data = load('neuron_trains.mat', 'neuron_trains');
    neuron_trains = cell2mat(data.neuron_trains);
    [N, T] = size(neuron_trains);

    % Random train/test split over time bins.
    p_train = 0.8;  % proportion of time bins used for training
    train_logical = false(T, 1);
    train_logical(1:round(p_train*T)) = true;
    train_logical = train_logical(randperm(T));
    neuron_trains_test = neuron_trains(:, ~train_logical);
    neuron_trains = neuron_trains(:, train_logical);

    % Empirical first and second moments of the training data.
    mean_experiment = mean(neuron_trains, 2).';                                   % 1 x N
    mean_experiment_product = neuron_trains*neuron_trains.'/size(neuron_trains, 2); % N x N

    % Independent-spin initialisation (presumably mirroring <s_i> = tanh(h_i)
    % for uncoupled +-1 spins). Clamp the means away from +-1 first: a neuron
    % that is always or never active in the training split would otherwise
    % get atanh(+-1) = +-Inf, which turns every later update into Inf/NaN.
    mean_limit = 1 - 1e-3;
    h0 = atanh(min(max(mean_experiment, -mean_limit), mean_limit));

    % Uniform(-0.1, 0.1) couplings. This is equivalent to
    % unifrnd(-0.1, 0.1, N, N) but does not need the Statistics Toolbox.
    J = -0.1 + 0.2*rand(N, N);

    eta = 0.1;            % learning rate
    alpha = 0;            % momentum coefficient (0 disables momentum)
    sample_size = 10000;  % number of sampler configurations per iteration
    prev_change_h0 = zeros(1, N);
    prev_change_J = zeros(N, N);

    % Per-iteration deviations of the model from the training data.
    sigma_diff = zeros(1, iters);
    corr_diff = zeros(1, iters);

    % Random +-1 starting configurations for the sampler. Each iteration's
    % output seeds the next one.
    sigm0 = 2*randi([0 1], sample_size, N) - 1;

    best_diff = Inf;
    best_h0 = h0;
    best_J = J;
    maxdiff = 1;  % value shown by the first progress line, before any update
    for itercount = 1:iters
        tic;
        % Progress: iteration number and the previous largest |J| update.
        disp([itercount maxdiff]);

        % Estimate the current model's moments by sampling.
        [sigm, states] = mh_sample_ising(1, sample_size, h0, J, 10, sigm0);
        sigm0 = sigm;
        [mean_sigma, mean_product] = weighted_moments(sigm, states);
        toc

        % Gradient step on the fields (with optional momentum).
        delta_h0 = eta*(mean_experiment - mean_sigma) + alpha*prev_change_h0;
        prev_change_h0 = delta_h0;
        h0 = h0 + delta_h0;
        sigma_diff(itercount) = mean(abs(mean_experiment - mean_sigma));

        % Gradient step on the couplings. The diagonal is never updated.
        delta_J = 0.5*eta*(mean_experiment_product - mean_product) + alpha*prev_change_J;
        delta_J(logical(eye(N))) = 0;
        maxdiff = max(abs(delta_J(:)));
        prev_change_J = delta_J;
        J = J + delta_J;
        corr_diff(itercount) = sum(abs(mean_experiment_product(:) - mean_product(:)))/(N^2);

        % Keep the parameters produced by the smallest coupling step.
        if maxdiff < best_diff
            best_diff = maxdiff;
            best_h0 = h0;
            best_J = J;
        end
        toc
    end
    J = best_J;
    h0 = best_h0;

    % Re-sample with the parameters actually returned, so the prediction
    % plots describe that model. The loop's last sample came from the last
    % iterate, which need not be the best one, and there is no sample at
    % all when iters == 0.
    [sigm, states] = mh_sample_ising(1, sample_size, h0, J, 10, sigm0);
    [mean_sigma, mean_product] = weighted_moments(sigm, states);

    % Plot the deviation from the training data over the iterations.
    figure;
    subplot(2, 1, 1);
    plot(1:iters, sigma_diff);
    title('Deviation of Mean Firing Rate');
    subplot(2, 1, 2);
    plot(1:iters, corr_diff);
    xlabel('# of Iterations');
    title('Deviation of Mean Correlation');

    % Held-out moments.
    n_test = size(neuron_trains_test, 2);
    if n_test == 0
        warning('estimate_ising:emptyTestSet', ...
            'No time bins were held out for testing; test statistics are NaN.');
    end
    mean_test = mean(neuron_trains_test, 2).';
    mean_test_product = neuron_trains_test*neuron_trains_test.'/n_test;

    % Pairwise covariances over the distinct pairs i < j (upper triangle).
    % The same mask orders both vectors, so entries stay paired.
    upper = triu(true(N), 1);
    cov_test = mean_test_product - mean_test.'*mean_test;
    cov_model = mean_product - mean_sigma.'*mean_sigma;
    meps = cov_test(upper).';
    mps = cov_model(upper).';

    % Plot predicted vs. empirical (held-out) values.
    plot_against_identity(mean_test, mean_sigma, ...
        'Mean Experimental Response', 'Mean Predicted Response');
    plot_against_identity(meps, mps, ...
        'Mean Experimental Correlation', 'Mean Predicted Correlation');
end

function [mean_sigma, mean_product] = weighted_moments(sigm, weights)
% WEIGHTED_MOMENTS Weighted first and second moments of sampled configurations.
%   Rows of SIGM are configurations. Row k is weighted by WEIGHTS(k). A
%   scalar weight applies to every row.
%   Returns mean_sigma = sum_k w_k*sigm(k,:) (1 x N) and
%   mean_product = sigm.'*diag(w)*sigm (N x N).
%   These are true means only if the weights sum to 1, which is assumed for
%   the MH_SAMPLE_ISING output (TODO confirm).
    weighted = bsxfun(@times, sigm, weights(:));
    mean_sigma = sum(weighted, 1);
    mean_product = sigm.'*weighted;
end

function plot_against_identity(x, y, x_label, y_label)
% PLOT_AGAINST_IDENTITY Scatter Y against X in a new figure with a red y = x line.
    figure;
    scatter(x, y, 10, 'b', 'filled');
    hold on;
    xlabel(x_label);
    ylabel(y_label);
    lo = min(min(x), min(y));
    hi = max(max(x), max(y));
    lin = linspace(lo, hi, 101);
    plot(lin, lin, 'r');
end