-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_fit.jl
62 lines (45 loc) · 1.78 KB
/
model_fit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
using JBT
using Serialization
include("utils.jl")
"""
    _fit_and_serialize(df, func_R_history, path; n_samples, n_chains)

Fit `logistic_past` to `df` and write the resulting chain to `path`.

`df` is converted with `df_to_JBT(df, func_R_history)`. The returned choices and
predictors are passed to `logistic_past`, which is sampled with `NUTS()` and
`MCMCThreads()`, requesting `n_samples` samples for each of `n_chains` chains. The chain
is written to `path` with `Serialization.serialize` (an existing file is overwritten)
and returned.
"""
function _fit_and_serialize(df, func_R_history, path::AbstractString;
                            n_samples::Integer, n_chains::Integer)
    choices, data = df_to_JBT(df, func_R_history)
    model = logistic_past(choices, data)
    # MCMCThreads() only samples chains in parallel when Julia was started with several
    # threads (e.g. `julia -t 4`); with a single thread the chains run sequentially.
    chain = sample(model, NUTS(), MCMCThreads(), n_samples, n_chains)
    serialize(path, chain)
    return chain
end

"""
    run_cohort_ketamine(func_R_history, model_label; n_samples=2000, n_chains=4)

Fit `logistic_past` to the `"1vs1"`, `"4vs1"` and `"ket"` cohorts and save the chain.

Data are loaded with `get_cohort_df` using `ID_excluded = ["HO2_4", "HO2_11"]` and
`S_excluded = ["1vs1_1"]`. `func_R_history` is forwarded to `df_to_JBT` (for example
`reward_rate` or `cue_reward_rate`). The chain is written to
`chain_ket_<model_label>.jls` (relative to the working directory) and returned.
"""
function run_cohort_ketamine(func_R_history, model_label;
                             n_samples::Integer=2000, n_chains::Integer=4)
    df = get_cohort_df(["1vs1", "4vs1", "ket"];
                       ID_excluded=["HO2_4", "HO2_11"],
                       S_excluded=["1vs1_1"])
    path = "chain_ket_$(model_label).jls"
    return _fit_and_serialize(df, func_R_history, path;
                              n_samples=n_samples, n_chains=n_chains)
end

"""
    run_cohort_amphetamine(func_R_history, model_label; n_samples=2000, n_chains=4)

Fit `logistic_past` to the `"amph"` cohort (no exclusions) and save the chain.

`func_R_history` is forwarded to `df_to_JBT`. The chain is written to
`chain_amph_<model_label>.jls` (relative to the working directory) and returned.
"""
function run_cohort_amphetamine(func_R_history, model_label;
                                n_samples::Integer=2000, n_chains::Integer=4)
    df = get_cohort_df(["amph"])
    path = "chain_amph_$(model_label).jls"
    return _fit_and_serialize(df, func_R_history, path;
                              n_samples=n_samples, n_chains=n_chains)
end

"""
    _run_batch_ketamine(func_R_history, model_label, condition; n_samples, n_chains)

Shared implementation of the ketamine batch fits.

`condition` is the directory/file suffix used by the public wrappers (`"no_effect"` or
`"effect"`). Data are read with `get_batch_df` from the `baseline_<condition>`,
`ketamine_<condition>` and `vehicle_<condition>` directories under `./exp/probe/`. The
chain is saved to `chain_<model_label>_<condition>_batch.jls` and returned.
"""
function _run_batch_ketamine(func_R_history, model_label, condition::AbstractString;
                             n_samples::Integer, n_chains::Integer)
    # Directory => label mapping handed to get_batch_df; presumably the label tags each
    # session's treatment group — NOTE(review): confirm against utils.jl.
    dirs = Dict(
        "./exp/probe/baseline/baseline_$(condition)/" => "baseline",
        "./exp/probe/ketamine/ketamine_$(condition)/" => "ketamine",
        "./exp/probe/ketamine/vehicle_$(condition)/" => "vehicle",
    )
    df = get_batch_df(dirs)
    path = "chain_$(model_label)_$(condition)_batch.jls"
    return _fit_and_serialize(df, func_R_history, path;
                              n_samples=n_samples, n_chains=n_chains)
end

"""
    run_batch_ketamine_no_effect(func_R_history, model_label; n_samples=1000, n_chains=4)

Fit `logistic_past` to the `*_no_effect` baseline/ketamine/vehicle probe batches.

The chain is written to `chain_<model_label>_no_effect_batch.jls` and returned.
"""
function run_batch_ketamine_no_effect(func_R_history, model_label;
                                      n_samples::Integer=1000, n_chains::Integer=4)
    return _run_batch_ketamine(func_R_history, model_label, "no_effect";
                               n_samples=n_samples, n_chains=n_chains)
end

"""
    run_batch_ketamine_effect(func_R_history, model_label; n_samples=1000, n_chains=4)

Fit `logistic_past` to the `*_effect` baseline/ketamine/vehicle probe batches.

The chain is written to `chain_<model_label>_effect_batch.jls` and returned.
"""
function run_batch_ketamine_effect(func_R_history, model_label;
                                   n_samples::Integer=1000, n_chains::Integer=4)
    return _run_batch_ketamine(func_R_history, model_label, "effect";
                               n_samples=n_samples, n_chains=n_chains)
end
# Script entry point: fit the amphetamine cohort once per reward-history regressor,
# in order "RR" then "learning"; each run writes its own chain file.
for (history_fn, label) in ((reward_rate, "RR"), (cue_reward_rate, "learning"))
    run_cohort_amphetamine(history_fn, label)
end