From b6c6ed96a36185c62ef97aff36dec22835812df1 Mon Sep 17 00:00:00 2001 From: cestpasphoto Date: Thu, 11 Feb 2021 11:11:11 +0100 Subject: [PATCH] Add Dirichlet Noise Copy paste from https://github.com/suragnair/alpha-zero-general/pull/186 --- Coach.py | 4 ++-- MCTS.py | 28 ++++++++++++++++++++++------ main.py | 2 +- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/Coach.py b/Coach.py index 02391d5..b989ec9 100644 --- a/Coach.py +++ b/Coach.py @@ -24,7 +24,7 @@ def __init__(self, game, nnet, args): self.nnet = nnet self.pnet = self.nnet.__class__(self.game, self.nnet.args) # the competitor network self.args = args - self.mcts = MCTS(self.game, self.nnet, self.args) + self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0)) self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples() @@ -86,7 +86,7 @@ def learn(self): iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue) for _ in tqdm(range(self.args.numEps), desc="Self Play", ncols=100): - self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree + self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0)) # reset search tree iterationTrainExamples += self.executeEpisode() # save the iteration examples to the history diff --git a/MCTS.py b/MCTS.py index b4b0013..dd05e20 100644 --- a/MCTS.py +++ b/MCTS.py @@ -13,10 +13,11 @@ class MCTS(): This class handles the MCTS tree. """ - def __init__(self, game, nnet, args): + def __init__(self, game, nnet, args, dirichlet_noise=False): self.game = game self.nnet = nnet self.args = args + self.dirichlet_noise = dirichlet_noise self.Qsa = {} # stores Q values for s,a (as defined in the paper) self.Nsa = {} # stores #times edge s,a was visited self.Ns = {} # stores #times board s was visited @@ -35,7 +36,8 @@ def getActionProb(self, canonicalBoard, temp=1): proportional to Nsa[(s,a)]**(1./temp) """ for i in range(self.args.numMCTSSims): - self.search(canonicalBoard) + dir_noise = (i == 0 and self.dirichlet_noise) + self.search(canonicalBoard, dirichlet_noise=dir_noise) s = self.game.stringRepresentation(canonicalBoard) counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())] @@ -52,7 +54,7 @@ def getActionProb(self, canonicalBoard, temp=1): probs = [x / counts_sum for x in counts] return probs - def search(self, canonicalBoard): + def search(self, canonicalBoard, dirichlet_noise=False): """ This function performs one iteration of MCTS. It is recursively called till a leaf node is found. The action chosen at each node is one that @@ -81,10 +83,11 @@ def search(self, canonicalBoard): return -self.Es[s] if s not in self.Ps: - # leaf node - self.Ps[s], v = self.nnet.predict(canonicalBoard) valids = self.game.getValidMoves(canonicalBoard, 1) - self.Ps[s] = self.Ps[s] * valids # masking invalid moves + # leaf node + self.Ps[s], v = self.nnet.predict(canonicalBoard, valids) + if dirichlet_noise: + self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) if sum_Ps_s > 0: self.Ps[s] /= sum_Ps_s # renormalize @@ -102,6 +105,10 @@ def search(self, canonicalBoard): return -v valids = self.Vs[s] + if dirichlet_noise: + self.applyDirNoise(s, valids) + sum_Ps_s = np.sum(self.Ps[s]) + self.Ps[s] /= sum_Ps_s # renormalize cur_best = -float('inf') best_act = -1 @@ -134,3 +141,12 @@ def search(self, canonicalBoard): self.Ns[s] += 1 return -v + + + def applyDirNoise(self, s, valids): + dir_values = np.random.dirichlet([self.args.dirichletAlpha] * np.count_nonzero(valids)) + dir_idx = 0 + for idx in range(len(self.Ps[s])): + if valids[idx]: + self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + (0.25 * dir_values[dir_idx]) + dir_idx += 1 diff --git a/main.py b/main.py index 2d0d542..b7d6d54 100644 --- a/main.py +++ b/main.py @@ -73,7 +73,7 @@ def main(): parser.add_argument('--maxlenOfQueue' , '-q' , action='store', default=200000, type=int , help='Number of game examples to train the neural networks') parser.add_argument('--numMCTSSims' , '-m' , action='store', default=25 , type=int , help='Number of games moves for MCTS to simulate.') parser.add_argument('--cpuct' , '-c' , action='store', default=1.0 , type=float, help='') - # parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position') + parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position') parser.add_argument('--numItersForTrainExamplesHistory', '-n', action='store', default=5, type=int, help='') parser.add_argument('--learn-rate' , '-l' , action='store', default=0.001, type=float, help='')