From d61f4ef132161c03daab75818098e0c704f73e14 Mon Sep 17 00:00:00 2001 From: Justin Date: Wed, 6 May 2020 15:38:08 -0500 Subject: [PATCH 1/4] dirichlet noise added to prior probabilities during self play --- Coach.py | 4 ++-- MCTS.py | 5 ++++- main.py | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Coach.py b/Coach.py index 8bf355c27..476cb7658 100644 --- a/Coach.py +++ b/Coach.py @@ -18,7 +18,7 @@ def __init__(self, game, nnet, args): self.nnet = nnet self.pnet = self.nnet.__class__(self.game) # the competitor network self.args = args - self.mcts = MCTS(self.game, self.nnet, self.args) + self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=True) self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples() @@ -82,7 +82,7 @@ def learn(self): end = time.time() for eps in range(self.args.numEps): - self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree + self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=True) # reset search tree iterationTrainExamples += self.executeEpisode() # bookkeeping + plot progress diff --git a/MCTS.py b/MCTS.py index 414133e58..f73884eb9 100644 --- a/MCTS.py +++ b/MCTS.py @@ -7,10 +7,11 @@ class MCTS(): This class handles the MCTS tree. """ - def __init__(self, game, nnet, args): + def __init__(self, game, nnet, args, dirichlet_noise=False): self.game = game self.nnet = nnet self.args = args + self.dirichlet_noise = dirichlet_noise self.Qsa = {} # stores Q values for s,a (as defined in the paper) self.Nsa = {} # stores #times edge s,a was visited self.Ns = {} # stores #times board s was visited @@ -77,6 +78,8 @@ def search(self, canonicalBoard): if s not in self.Ps: # leaf node self.Ps[s], v = self.nnet.predict(canonicalBoard) + if self.dirichlet_noise: + self.Ps[s] = (0.75 * self.Ps[s]) + (0.25 * np.random.dirichlet([self.args.dirichletAlpha] * len(self.Ps[s]))) valids = self.game.getValidMoves(canonicalBoard, 1) self.Ps[s] = self.Ps[s]*valids # masking invalid moves sum_Ps_s = np.sum(self.Ps[s]) diff --git a/main.py b/main.py index c7c3437e9..3444e27cc 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ 'numMCTSSims': 25, # Number of games moves for MCTS to simulate. 'arenaCompare': 40, # Number of games to play during arena play to determine if new net will be accepted. 'cpuct': 1, + 'dirichletAlpha': 0.6, # α = {0.3, 0.15, 0.03} for chess, shogi and Go respectively, scaled in inverse proportion to the approximate number of legal moves in a typical position 'checkpoint': './temp/', 'load_model': False, From 3bb5260f81e040d6ecf81fd01927f9b68ad11001 Mon Sep 17 00:00:00 2001 From: Justin Date: Sun, 10 May 2020 09:13:15 -0500 Subject: [PATCH 2/4] apply dirichlet noise only to valid moves and only at S0 --- MCTS.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/MCTS.py b/MCTS.py index f73884eb9..915217d69 100644 --- a/MCTS.py +++ b/MCTS.py @@ -30,7 +30,8 @@ def getActionProb(self, canonicalBoard, temp=1): proportional to Nsa[(s,a)]**(1./temp) """ for i in range(self.args.numMCTSSims): - self.search(canonicalBoard) + if i == 0 and self.dirichlet_noise: + self.search(canonicalBoard, dirichlet_noise=True) s = self.game.stringRepresentation(canonicalBoard) counts = [self.Nsa[(s,a)] if (s,a) in self.Nsa else 0 for a in range(self.game.getActionSize())] @@ -47,7 +48,7 @@ def getActionProb(self, canonicalBoard, temp=1): return probs - def search(self, canonicalBoard): + def search(self, canonicalBoard, dirichlet_noise=False): """ This function performs one iteration of MCTS. It is recursively called till a leaf node is found. The action chosen at each node is one that @@ -78,10 +79,10 @@ def search(self, canonicalBoard): if s not in self.Ps: # leaf node self.Ps[s], v = self.nnet.predict(canonicalBoard) - if self.dirichlet_noise: - self.Ps[s] = (0.75 * self.Ps[s]) + (0.25 * np.random.dirichlet([self.args.dirichletAlpha] * len(self.Ps[s]))) valids = self.game.getValidMoves(canonicalBoard, 1) self.Ps[s] = self.Ps[s]*valids # masking invalid moves + if self.dirichlet_noise: + self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) if sum_Ps_s > 0: self.Ps[s] /= sum_Ps_s # renormalize @@ -99,6 +100,10 @@ def search(self, canonicalBoard): return -v valids = self.Vs[s] + if self.dirichlet_noise: + self.applyDirNoise(s, valids) + sum_Ps_s = np.sum(self.Ps[s]) + self.Ps[s] /= sum_Ps_s # renormalize cur_best = -float('inf') best_act = -1 @@ -130,3 +135,11 @@ def search(self, canonicalBoard): self.Ns[s] += 1 return -v + + def applyDirNoise(self, s, valids): + dir_values = np.random.dirichlet([self.args.dirichletAlpha] * np.count_nonzero(valids)) + dir_idx = 0 + for idx in range(len(self.Ps[s])): + if self.Ps[s][idx]: + self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + (0.25 * dir_values[dir_idx]) + dir_idx += 1 From a0c40ebb3ce55c63efdda2ba1cfb73998476ce3a Mon Sep 17 00:00:00 2001 From: Justin Date: Sun, 10 May 2020 09:15:28 -0500 Subject: [PATCH 3/4] debug calling search with dir_noise --- MCTS.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MCTS.py b/MCTS.py index 915217d69..d29e5f6d5 100644 --- a/MCTS.py +++ b/MCTS.py @@ -30,8 +30,8 @@ def getActionProb(self, canonicalBoard, temp=1): proportional to Nsa[(s,a)]**(1./temp) """ for i in range(self.args.numMCTSSims): - if i == 0 and self.dirichlet_noise: - self.search(canonicalBoard, dirichlet_noise=True) + dir_noise = (i == 0 and self.dirichlet_noise) + self.search(canonicalBoard, dirichlet_noise=dir_noise) s = self.game.stringRepresentation(canonicalBoard) counts = [self.Nsa[(s,a)] if (s,a) in self.Nsa else 0 for a in range(self.game.getActionSize())] From 7cb2e1746d985d83121493cb1fbd5693ea57b2a4 Mon Sep 17 00:00:00 2001 From: Justin Date: Sat, 16 May 2020 10:17:26 -0500 Subject: [PATCH 4/4] use local paramter dirichlet_noise instead of self.dirichlet_noise in mcts.search --- MCTS.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MCTS.py b/MCTS.py index d29e5f6d5..68cef1dc9 100644 --- a/MCTS.py +++ b/MCTS.py @@ -81,7 +81,7 @@ def search(self, canonicalBoard, dirichlet_noise=False): self.Ps[s], v = self.nnet.predict(canonicalBoard) valids = self.game.getValidMoves(canonicalBoard, 1) self.Ps[s] = self.Ps[s]*valids # masking invalid moves - if self.dirichlet_noise: + if dirichlet_noise: self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) if sum_Ps_s > 0: @@ -100,7 +100,7 @@ def search(self, canonicalBoard, dirichlet_noise=False): return -v valids = self.Vs[s] - if self.dirichlet_noise: + if dirichlet_noise: self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) self.Ps[s] /= sum_Ps_s # renormalize