Skip to content

Commit

Permalink
Add Dirichlet Noise
Browse files Browse the repository at this point in the history
  • Loading branch information
cestpasphoto committed Feb 11, 2021
1 parent 865270f commit b6c6ed9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Coach.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, game, nnet, args):
self.nnet = nnet
self.pnet = self.nnet.__class__(self.game, self.nnet.args) # the competitor network
self.args = args
self.mcts = MCTS(self.game, self.nnet, self.args)
self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0))
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples()

Expand Down Expand Up @@ -86,7 +86,7 @@ def learn(self):
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

for _ in tqdm(range(self.args.numEps), desc="Self Play", ncols=100):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0)) # reset search tree
iterationTrainExamples += self.executeEpisode()

# save the iteration examples to the history
Expand Down
28 changes: 22 additions & 6 deletions MCTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ class MCTS():
This class handles the MCTS tree.
"""

def __init__(self, game, nnet, args):
def __init__(self, game, nnet, args, dirichlet_noise=False):
self.game = game
self.nnet = nnet
self.args = args
self.dirichlet_noise = dirichlet_noise
self.Qsa = {} # stores Q values for s,a (as defined in the paper)
self.Nsa = {} # stores #times edge s,a was visited
self.Ns = {} # stores #times board s was visited
Expand All @@ -35,7 +36,8 @@ def getActionProb(self, canonicalBoard, temp=1):
proportional to Nsa[(s,a)]**(1./temp)
"""
for i in range(self.args.numMCTSSims):
self.search(canonicalBoard)
dir_noise = (i == 0 and self.dirichlet_noise)
self.search(canonicalBoard, dirichlet_noise=dir_noise)

s = self.game.stringRepresentation(canonicalBoard)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
Expand All @@ -52,7 +54,7 @@ def getActionProb(self, canonicalBoard, temp=1):
probs = [x / counts_sum for x in counts]
return probs

def search(self, canonicalBoard):
def search(self, canonicalBoard, dirichlet_noise=False):
"""
This function performs one iteration of MCTS. It is recursively called
till a leaf node is found. The action chosen at each node is one that
Expand Down Expand Up @@ -81,10 +83,11 @@ def search(self, canonicalBoard):
return -self.Es[s]

if s not in self.Ps:
# leaf node
self.Ps[s], v = self.nnet.predict(canonicalBoard)
valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s] * valids # masking invalid moves
# leaf node
self.Ps[s], v = self.nnet.predict(canonicalBoard, valids)
if dirichlet_noise:
self.applyDirNoise(s, valids)
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # renormalize
Expand All @@ -102,6 +105,10 @@ def search(self, canonicalBoard):
return -v

valids = self.Vs[s]
if dirichlet_noise:
self.applyDirNoise(s, valids)
sum_Ps_s = np.sum(self.Ps[s])
self.Ps[s] /= sum_Ps_s # renormalize
cur_best = -float('inf')
best_act = -1

Expand Down Expand Up @@ -134,3 +141,12 @@ def search(self, canonicalBoard):

self.Ns[s] += 1
return -v


def applyDirNoise(self, s, valids):
dir_values = np.random.dirichlet([self.args.dirichletAlpha] * np.count_nonzero(valids))
dir_idx = 0
for idx in range(len(self.Ps[s])):
if valids[idx]:
self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + (0.25 * dir_values[dir_idx])
dir_idx += 1
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def main():
parser.add_argument('--maxlenOfQueue' , '-q' , action='store', default=200000, type=int , help='Number of game examples to train the neural networks')
parser.add_argument('--numMCTSSims' , '-m' , action='store', default=25 , type=int , help='Number of games moves for MCTS to simulate.')
parser.add_argument('--cpuct' , '-c' , action='store', default=1.0 , type=float, help='')
# parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position')
parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position')
parser.add_argument('--numItersForTrainExamplesHistory', '-n', action='store', default=5, type=int, help='')

parser.add_argument('--learn-rate' , '-l' , action='store', default=0.001, type=float, help='')
Expand Down

0 comments on commit b6c6ed9

Please sign in to comment.