Fix a typo in Coach.py #322

Open · wants to merge 10 commits into master
10 changes: 10 additions & 0 deletions .gitignore
@@ -11,3 +11,13 @@ checkpoints/
# For PyCharm users
.idea/

# environment
myenv/
env/

# handy tests
test.py

*.ipynb
activate.sh
.ipynb_checkpoints/
75 changes: 56 additions & 19 deletions Arena.py
@@ -1,6 +1,8 @@
import logging

from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures

log = logging.getLogger(__name__)

@@ -10,6 +12,7 @@ class Arena():
An Arena class where any 2 agents can be pitted against each other.
"""

NUM_WORKERS = 16  # number of games played concurrently by playGames
def __init__(self, player1, player2, game, display=None):
"""
Input:
@@ -76,6 +79,7 @@ def playGame(self, verbose=False):
assert self.display
print("Game over: Turn ", str(it), "Result ", str(self.game.getGameEnded(board, 1)))
self.display(board)

return curPlayer * self.game.getGameEnded(board, curPlayer)

def playGames(self, num, verbose=False):
@@ -93,24 +97,57 @@ def playGames(self, num, verbose=False):
oneWon = 0
twoWon = 0
draws = 0
for _ in tqdm(range(num), desc="Arena.playGames (1)"):
gameResult = self.playGame(verbose=verbose)
if gameResult == 1:
oneWon += 1
elif gameResult == -1:
twoWon += 1
else:
draws += 1

self.player1, self.player2 = self.player2, self.player1

for _ in tqdm(range(num), desc="Arena.playGames (2)"):
gameResult = self.playGame(verbose=verbose)
if gameResult == -1:
oneWon += 1
elif gameResult == 1:
twoWon += 1
else:
draws += 1
with ThreadPoolExecutor(max_workers=Arena.NUM_WORKERS) as executor:
futures = [executor.submit(self.playGame, verbose=verbose) for _ in range(num)]
gameResult = []
with tqdm(total=num, desc=f"Arena.playGames (1) with {Arena.NUM_WORKERS} workers") as pbar:
for future in concurrent.futures.as_completed(futures):
try:
gameResult.append(future.result())
except Exception as e:
log.error(f"Exception in a worker: {e}")
finally:
pbar.update(1)
oneWon = gameResult.count(1)
twoWon = gameResult.count(-1)
draws = num - oneWon - twoWon

# for _ in tqdm(range(num), desc="Arena.playGames (1)"):
# gameResult = self.playGame(verbose=verbose)
# if gameResult == 1:
# oneWon += 1
# elif gameResult == -1:
# twoWon += 1
# else:
# draws += 1

# switch players (white and black)
self.player1, self.player2 = self.player2, self.player1
gameResult = []
futures = [executor.submit(self.playGame, verbose=verbose) for _ in range(num)]

with tqdm(total=num, desc=f"Arena.playGames (2) with {Arena.NUM_WORKERS} workers") as pbar:
for future in concurrent.futures.as_completed(futures):
try:
gameResult.append(future.result())
except Exception as e:
log.error(f"Exception in a worker: {e}")
finally:
pbar.update(1)

oneWon += gameResult.count(-1)
twoWon += gameResult.count(1)
draws = num * 2 - oneWon - twoWon

# for _ in tqdm(range(num), desc="Arena.playGames (2)"):
# gameResult = self.playGame(verbose=verbose)
# if gameResult == -1:
# oneWon += 1
# elif gameResult == 1:
# twoWon += 1
# else:
# draws += 1

return oneWon, twoWon, draws
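
For orientation, this is the submit / as_completed / tqdm pattern that both playGames rounds now use, reduced to a self-contained sketch. simulate_game and play_games are illustrative stand-ins, not repo code:

```python
import concurrent.futures
import random
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

NUM_WORKERS = 16  # mirrors Arena.NUM_WORKERS in the PR

def simulate_game() -> int:
    """Stand-in for Arena.playGame: returns 1, -1, or 0 for a draw."""
    return random.choice([1, -1, 0])

def play_games(num):
    results = []
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # Submit all games up front, then drain them as they complete.
        futures = [executor.submit(simulate_game) for _ in range(num)]
        with tqdm(total=num, desc=f"playGames with {NUM_WORKERS} workers") as pbar:
            for future in concurrent.futures.as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    print(f"Exception in a worker: {e}")
                finally:
                    pbar.update(1)
    one_won = results.count(1)
    two_won = results.count(-1)
    return one_won, two_won, num - one_won - two_won

print(play_games(8))
```

Note that under CPython's GIL, a ThreadPoolExecutor only yields a real speedup when playGame spends most of its time releasing the GIL (e.g. inside NumPy or the neural-network framework), and all workers share one Arena instance, so playGame must avoid mutating shared state.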


48 changes: 32 additions & 16 deletions Coach.py
@@ -4,7 +4,8 @@
from collections import deque
from pickle import Pickler, Unpickler
from random import shuffle

from concurrent.futures import ThreadPoolExecutor
import concurrent
import numpy as np
from tqdm import tqdm

@@ -23,7 +24,7 @@ class Coach():
def __init__(self, game, nnet, args):
self.game = game
self.nnet = nnet
self.pnet = self.nnet.__class__(self.game) # the competitor network
self.pnet = self.nnet.__class__(self.game, input_channels=self.nnet.args.input_channels, num_channels=self.nnet.args.num_channels)  # the competitor network
self.args = args
self.mcts = MCTS(self.game, self.nnet, self.args)
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
@@ -41,32 +42,37 @@ def executeEpisode(self):
uses temp=0.

Returns:
trainExamples: a list of examples of the form (canonicalBoard, currPlayer, pi,v)
trainExamples: a list of examples of the form (canonicalBoard, pi, v)
pi is the MCTS informed policy vector, v is +1 if
the player eventually won the game, else -1.
"""
trainExamples = []
board = self.game.getInitBoard()
self.curPlayer = 1
curPlayer = 1
episodeStep = 0

mcts = MCTS(self.game, self.nnet, self.args)  # reset the search tree. TODO: do we really need to reset?

while True:
episodeStep += 1
canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
canonicalBoard = self.game.getCanonicalForm(board, curPlayer)
temp = int(episodeStep < self.args.tempThreshold)

pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
pi = mcts.getActionProb(canonicalBoard, temp=temp)
sym = self.game.getSymmetries(canonicalBoard, pi)
for b, p in sym:
trainExamples.append([b, self.curPlayer, p, None])
trainExamples.append([b, curPlayer, p, None])

action = np.random.choice(len(pi), p=pi)
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
board, curPlayer = self.game.getNextState(board, curPlayer, action)

r = self.game.getGameEnded(board, self.curPlayer)
r = self.game.getGameEnded(board, curPlayer)

if r != 0:
return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]
if r == 2:
# the game ended in a draw, so no reward signal was collected.
# TODO: should we drop these training examples instead?
r = 0
return [(x[0], x[2], r * ((-1) ** (x[1] != curPlayer))) for x in trainExamples]

def learn(self):
"""
@@ -84,10 +90,19 @@ def learn(self):
if not self.skipFirstSelfPlay or i > 1:
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

for _ in tqdm(range(self.args.numEps), desc="Self Play"):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
iterationTrainExamples += self.executeEpisode()

with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor:
# Launch async simulations
futures = [executor.submit(self.executeEpisode) for _ in range(self.args.numEps)]

with tqdm(total=self.args.numEps, desc=f"Self Play with {self.args.num_workers} workers") as pbar:
for future in concurrent.futures.as_completed(futures):
try:
iterationTrainExamples += future.result()
except Exception as e:
log.error(f"Exception in a worker: {e}")
finally:
pbar.update(1)

# save the iteration examples to the history
self.trainExamplesHistory.append(iterationTrainExamples)

@@ -119,7 +134,7 @@ def learn(self):
pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

log.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold or (pwins + nwins) < self.args.arenaCompare * 0.15:  # also reject if too few games were decisive
log.info('REJECTING NEW MODEL')
self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
else:
@@ -135,6 +150,7 @@ def saveTrainExamples(self, iteration):
if not os.path.exists(folder):
os.makedirs(folder)
filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
log.info(f"saving train examples: {filename}")
with open(filename, "wb+") as f:
Pickler(f).dump(self.trainExamplesHistory)
f.closed
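
As a quick check of the value backfill in executeEpisode: each stored example carries the player who was to move, and the final result r is sign-flipped for examples recorded from the other player's perspective. A minimal sketch with dummy boards (not repo code):

```python
# Each stored example is [board, player_to_move, pi, None], as in executeEpisode.
train_examples = [
    ("board0", 1, "pi0", None),   # player 1 was to move
    ("board1", -1, "pi1", None),  # player -1 was to move
]

r = 1            # getGameEnded result, from the perspective of...
cur_player = -1  # ...the player to move when the game ended

# r * (-1)**(stored_player != cur_player): the same player keeps r, the opponent gets -r.
labelled = [(x[0], x[2], r * ((-1) ** (x[1] != cur_player))) for x in train_examples]
print(labelled)  # [('board0', 'pi0', -1), ('board1', 'pi1', 1)]
```

With the PR's draw handling (r == 2 mapped to r = 0), every example from a drawn game is labelled 0.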
111 changes: 73 additions & 38 deletions MCTS.py
@@ -1,5 +1,7 @@
import logging
import math
import sys


import numpy as np

@@ -24,33 +26,46 @@ def __init__(self, game, nnet, args):

self.Es = {} # stores game.getGameEnded ended for board s
self.Vs = {} # stores game.getValidMoves for board s

def getActionProb(self, canonicalBoard, temp=1):
"""
This function performs numMCTSSims simulations of MCTS starting from
canonicalBoard.
Performs numMCTSSims MCTS simulations starting from canonicalBoard.

Returns:
probs: a policy vector where the probability of the ith action is
proportional to Nsa[(s,a)]**(1./temp)
"""
for i in range(self.args.numMCTSSims):

for _ in range(self.args.numMCTSSims):
self.search(canonicalBoard)

# Compute action probabilities
s = self.game.stringRepresentation(canonicalBoard)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

counts = np.array(
[self.Nsa.get((s, a), 0) for a in range(self.game.getActionSize())],
dtype=np.float32
)

if 'verbose' in self.args and self.args.verbose == 1:
total_counts = counts.sum()
probs = counts.reshape(canonicalBoard.shape)
MCTS.display(probs)
MCTS.display(probs / (total_counts + EPS))
s = self.game.stringRepresentation(canonicalBoard)
probs = np.array(self.Ps[s]).reshape(canonicalBoard.shape)
MCTS.display(probs)

if temp == 0:
bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
bestA = np.random.choice(bestAs)
probs = [0] * len(counts)
bestA = np.random.choice(np.flatnonzero(counts == counts.max()))
probs = np.zeros_like(counts, dtype=np.float32)
probs[bestA] = 1
return probs

counts = [x ** (1. / temp) for x in counts]
counts_sum = float(sum(counts))
probs = [x / counts_sum for x in counts]
return probs
else:
counts = counts ** (1. / temp)
probs = counts / (counts.sum() + EPS)
return probs
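
For intuition about the temperature step above: raising visit counts to the power 1/temp sharpens the distribution as temp drops below 1. A tiny worked example with illustrative counts:

```python
import numpy as np

counts = np.array([8.0, 2.0])           # visit counts N(s, a)
for temp in (1.0, 0.5):
    scaled = counts ** (1.0 / temp)     # temp=0.5 squares the counts
    print(temp, scaled / scaled.sum())  # 1.0 -> [0.8 0.2], 0.5 -> [~0.94 ~0.06]
```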

def search(self, canonicalBoard):
"""
@@ -74,17 +89,24 @@ def search(self, canonicalBoard):

s = self.game.stringRepresentation(canonicalBoard)

# Check terminal state
if s not in self.Es:
self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)

if self.Es[s] != 0:
# terminal node
if self.Es[s] == 2:
# draw
return 0
return -self.Es[s]

# Expand the leaf node
if s not in self.Ps:
# leaf node

self.Ps[s], v = self.nnet.predict(canonicalBoard)

valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s] * valids # masking invalid moves
self.Ps[s] *= valids # masking invalid moves
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # renormalize
@@ -94,43 +116,56 @@ def search(self, canonicalBoard):
# NB! All valid moves may be masked if your NNet architecture is insufficient, you've got overfitting, or something else.
# If you get dozens or hundreds of these messages, you should pay attention to your NNet and/or training process.
log.error("All valid moves were masked, doing a workaround.")
self.Ps[s] = self.Ps[s] + valids
self.Ps[s] /= np.sum(self.Ps[s])
self.Ps[s] = valids / valids.sum()

self.Vs[s] = valids
self.Ns[s] = 0
return -v

valids = self.Vs[s]
cur_best = -float('inf')
best_act = -1
sqrt_Ns = math.sqrt(self.Ns[s] + EPS)

# UCB score for every action (invalid moves get -inf)
ucb_values = np.array([
self.Qsa.get((s, a), 0) +
self.args.cpuct * self.Ps[s][a] * sqrt_Ns / (1 + self.Nsa.get((s, a), 0))
if valids[a] else -float('inf')
for a in range(self.game.getActionSize())
])

# pick the action with the highest upper confidence bound
for a in range(self.game.getActionSize()):
if valids[a]:
if (s, a) in self.Qsa:
u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
1 + self.Nsa[(s, a)])
else:
u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS) # Q = 0 ?

if u > cur_best:
cur_best = u
best_act = a

a = best_act
next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
best_act = np.argmax(ucb_values)
next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
next_s = self.game.getCanonicalForm(next_s, next_player)

v = self.search(next_s)

if (s, a) in self.Qsa:
self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
self.Nsa[(s, a)] += 1
if (s, best_act) in self.Qsa:
self.Qsa[(s, best_act)] = (self.Nsa[(s, best_act)] * self.Qsa[(s, best_act)] + v) / (self.Nsa[(s, best_act)] + 1)
self.Nsa[(s, best_act)] += 1

else:
self.Qsa[(s, a)] = v
self.Nsa[(s, a)] = 1
self.Qsa[(s, best_act)] = v
self.Nsa[(s, best_act)] = 1

self.Ns[s] += 1
return -v


@staticmethod
def display(board):
n = board.shape[0]
print(" ", end="")
for y in range(n):
print(y, end=" ")
print("")
print("-----------------------")
for y in range(n):
print(y, "|", end="") # print the row #
for x in range(n):
piece = board[x][y] # get the piece to print
print(f"{piece:.2f}", end=" ")
print("|")

print("-----------------------")
