forked from ahmadianlab/gg3_nda
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtask_3_2_1.py
40 lines (33 loc) · 1.21 KB
/
task_3_2_1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# import
from inference import *
from HMM_models import *
from HMM_inference import *
import numpy as np
# Uniform priors over the parameter grids. beta_space / sigma_space (ramp) and
# m_space / r_space (step) are expected to come from the star imports above.
ramp_priori = np.ones([len(beta_space), len(sigma_space)]) / (
    len(beta_space) * len(sigma_space)
)
step_priori = np.ones([len(m_space), len(r_space)]) / (len(m_space) * len(r_space))

# Number of simulated datasets per generating model (renamed from `iter`,
# which shadowed the builtin).
n_iter = 100
# Number of trials generated per simulated dataset.
N = 5


def _simulate_decisions(make_model, is_error, n_iter, n_trials, ramp_prior, step_prior):
    """Repeatedly simulate data from one model family and score the Bayes factor.

    Args:
        make_model: zero-argument callable that samples parameters and returns
            a fresh HMM instance to generate data from.
        is_error: predicate on a single Bayes-factor value; returns True when
            that value counts as a wrong model decision.
        n_iter: number of independent datasets to simulate.
        n_trials: number of trials per dataset, passed to ``generate_N_trials``.
        ramp_prior, step_prior: prior grids forwarded to ``compute_bayes_factor``.

    Returns:
        (bayes_factors, n_errors): a 1-D array of every Bayes factor computed,
        in simulation order, and the number of values for which ``is_error``
        returned True.
    """
    bayes_factors = []
    n_errors = 0
    for _ in range(n_iter):
        datas = generate_N_trials(n_trials, make_model())
        bayes = compute_bayes_factor(datas, ramp_prior, step_prior)
        bayes_factors.append(bayes)
        if is_error(bayes):
            n_errors += 1
    # np.ravel mirrors np.append's flattening; starting from an empty list
    # avoids the uninitialised leading element that np.empty([]) injected.
    return np.ravel(bayes_factors), n_errors


def _sample_step_model():
    """Draw (m, r) from the step prior and build a step HMM with them."""
    m, r = sample_from_priori(step_priori, model='step')
    return HMM_Step(m, r, x0, Rh, T)


def _sample_ramp_model():
    """Draw (beta, sigma) from the ramp prior and build a ramp HMM with them."""
    beta, sigma = sample_from_priori(ramp_priori, model='ramp')
    return HMM_Ramp(beta, sigma, K, x0, Rh, T)


# NOTE(review): the value printed below is error_count / n_iter under the label
# "accuracy". Whether these counters count wrong or correct decisions depends on
# the sign convention of compute_bayes_factor (not visible here) -- confirm and
# fix either the label or the counter names.

# Data generated by the step model: a negative Bayes factor is counted.
result_step, error_step = _simulate_decisions(
    _sample_step_model, lambda b: b < 0, n_iter, N, ramp_priori, step_priori
)
print('Step decision accuracy', error_step / n_iter)

# Data generated by the ramp model: a positive Bayes factor is counted.
result_ramp, error_ramp = _simulate_decisions(
    _sample_ramp_model, lambda b: b > 0, n_iter, N, ramp_priori, step_priori
)
print('Ramp decision accuracy', error_ramp / n_iter)