-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbandittestframe.py
33 lines (25 loc) · 1.04 KB
/
bandittestframe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from bernoulliarm import *
def test_algorithm(algo, arms, num_sims, horizon):
    """Run repeated bandit simulations and record every step.

    For each of ``num_sims`` independent simulations the algorithm is reset
    with ``algo.initialize(len(arms))`` and then plays ``horizon`` rounds.
    In each round it picks an arm via ``algo.select_arm()``, draws a reward
    with ``arms[chosen_arm].draw()`` and reports it back through
    ``algo.update(chosen_arm, reward)``.

    Args:
        algo: bandit algorithm exposing ``initialize(n_arms)``,
            ``select_arm()`` and ``update(chosen_arm, reward)``.
        arms: indexable collection of arm objects exposing ``draw()``;
            indexed by the value returned from ``select_arm()``.
        num_sims: number of independent simulations to run.
        horizon: number of rounds played in each simulation.

    Returns:
        ``[sim_nums, times, chosen_arms, rewards, cumulative_rewards]`` --
        five parallel lists, each of length ``num_sims * horizon``, in
        simulation-major order. ``sim_nums`` and ``times`` are 1-based.
        ``cumulative_rewards`` holds the running reward total within the
        current simulation and restarts at each simulation's first round.
    """
    total_steps = num_sims * horizon
    # Every slot is overwritten below; the 0.0 fill only sizes the lists.
    chosen_arms = [0.0] * total_steps
    rewards = [0.0] * total_steps
    cumulative_rewards = [0.0] * total_steps
    sim_nums = [0.0] * total_steps
    times = [0.0] * total_steps

    for sim in range(1, num_sims + 1):
        # Reset the algorithm so each simulation starts from scratch.
        algo.initialize(len(arms))
        for t in range(1, horizon + 1):
            index = (sim - 1) * horizon + t - 1
            sim_nums[index] = sim
            times[index] = t

            chosen_arm = algo.select_arm()
            chosen_arms[index] = chosen_arm

            reward = arms[chosen_arm].draw()
            rewards[index] = reward

            if t == 1:
                # First round of a simulation: the running total restarts.
                cumulative_rewards[index] = reward
            else:
                # Store the running total; computing the sum without
                # assigning it would leave this slot at 0.0.
                cumulative_rewards[index] = cumulative_rewards[index - 1] + reward

            algo.update(chosen_arm, reward)

    return [sim_nums, times, chosen_arms, rewards, cumulative_rewards]


# The ``test_`` prefix would otherwise make pytest collect this simulation
# driver as a test case and fail on its non-fixture parameters.
test_algorithm.__test__ = False