-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmcplanner.cc
83 lines (65 loc) · 1.77 KB
/
mcplanner.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <iostream>
#include <cmath>
#include "mcplanner.hh"
/*
 * Store the per-call rollout sample budget and derive the rollout depth
 * limit (maxDepth) from the epsilon-horizon.
 *
 * queries: number of simulated steps, as counted by domain->getNumSamples(),
 *          that one call to plan() may spend.
 *
 * maxDepth is computed as
 *   H = ceil( log(epsilon * (1 - gamma) / rmax) / log(gamma) ),
 * the smallest H with rmax * gamma^H / (1 - gamma) <= epsilon. Assuming
 * rmax bounds the per-step reward, rewards beyond that depth can change a
 * discounted return by at most epsilon. The formula is only meaningful for
 * 0 < gamma < 1 and rmax > 0.
 */
void MCPlanner::setMaxQueries(int queries)
{
  /* Sample budget consumed by plan()'s rollout loop */
  maxQueries = queries;

  /* Epsilon-horizon as a double before converting to an int depth */
  double gamma = domain->getDiscountFactor();
  double rmax = domain->getRmax();
  double horizon = ceil( log(epsilon*(1-gamma)/rmax)/log(gamma) );

  /* If epsilon*(1-gamma) >= rmax the horizon is negative, and degenerate
   * parameters can make it NaN. A negative maxDepth makes search() return
   * at depth 0 without simulating, so plan()'s budget loop could never
   * make progress. Clamp to 0 so every rollout takes at least one
   * simulated step. The comparison below is false for NaN, which also
   * maps to 0 (casting NaN to int would be undefined behaviour).
   * NOTE(review): gamma extremely close to 1 can push horizon past
   * INT_MAX; that case is not guarded here. */
  maxDepth = (horizon > 0.0) ? static_cast<int>(horizon) : 0;
}
/*
 * Choose an action for state s using Monte Carlo rollouts.
 *
 * Clears the planner's data structures with reset(), then repeatedly runs
 * search() from s at depth 0 until maxQueries simulated steps (measured by
 * the domain's sample counter, relative to its value on entry) have been
 * spent. Returns the action selectAction() picks greedily at the root.
 */
Action MCPlanner::plan(State s)
{
  /* Clear planner data structures from any previous call */
  reset();

  /* The budget is relative to the sample counter at entry */
  numInitialSamples = domain->getNumSamples();

  /* Keep rolling out from the root until this call's budget is used up */
  for (;;) {
    if ((domain->getNumSamples() - numInitialSamples) >= maxQueries) {
      break;
    }
    search(0, s, false);
  }

  /* Greedy choice at the root */
  return selectAction(s, 0, true);
}
/*
 * Run one Monte Carlo rollout from state s and return its discounted return.
 *
 * depth:    distance from the root (plan() starts at 0).
 * s:        state to act from.
 * terminal: true if s was reached through a terminal transition; the
 *           rollout then contributes 0.
 *
 * Returns reward + gamma * (return of the rest of the rollout). The rollout
 * stops at a terminal transition or once depth exceeds maxDepth. While the
 * sample budget of the current plan() call is not exhausted, each simulated
 * transition is also passed to updateValue() with the return divided by vmax.
 */
double MCPlanner::search(int depth, State s, bool terminal)
{
  /* A terminal state contributes no further reward */
  if(terminal) {
    return 0;
  }
  /* Past the depth limit, return 0 (a heuristic leaf value could be used
   * here instead) */
  if(depth > maxDepth) {
    return 0;
  }
  /* Select the next action non-greedily, i.e. via selectAction()'s
   * exploration heuristic */
  Action a = selectAction(s, depth, false);
  /* Sample a transition from the generative model; it is deleted below.
   * NOTE(review): a std::unique_ptr would make this leak-free if
   * search()/updateValue() can throw and the project builds as C++11+. */
  SARS *sars = domain->simulate(s,a);
  double gamma = domain->getDiscountFactor();
  /* Discounted return: immediate reward plus discounted rest of rollout */
  double q = sars->reward + gamma * search(depth + 1, sars->s_prime, sars->terminal);
  /* Update the value estimate only while samples remain in the budget */
  if((domain->getNumSamples() - numInitialSamples) < maxQueries) {
    /* Depths depth..maxDepth leave `remaining` steps, and
     * vmax = sum_{k=0}^{remaining-1} gamma^k is the largest return if every
     * reward is at most 1 (NOTE(review): not scaled by rmax; confirm
     * intended). The closed form is 0/0 at gamma == 1, where the series
     * is exactly `remaining`. */
    int remaining = maxDepth - depth + 1;
    double vmax = (gamma == 1.0)
                    ? static_cast<double>(remaining)
                    : (1-pow(gamma, remaining))/(1-gamma);
    updateValue(depth, sars, q/vmax);
  }
  delete sars;
  return q;
}