diff --git a/othello/pytorch/NNet.py b/othello/pytorch/NNet.py index da189d1c4..24dd2954f 100644 --- a/othello/pytorch/NNet.py +++ b/othello/pytorch/NNet.py @@ -147,6 +147,6 @@ def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'): # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98 filepath = os.path.join(folder, filename) if not os.path.exists(filepath): - raise("No model in path {}".format(checkpoint)) + raise("No model in path {}".format(filepath)) checkpoint = torch.load(filepath) self.nnet.load_state_dict(checkpoint['state_dict']) diff --git a/pit.py b/pit.py index dec1a5a83..1a7ba0904 100644 --- a/pit.py +++ b/pit.py @@ -21,7 +21,7 @@ # nnet players n1 = NNet(g) -n1.load_checkpoint('./pretrained_models/','6x100x25_best.pth.tar') +n1.load_checkpoint('./pretrained_models/othello/pytorch/','6x100x25_best.pth.tar') args1 = dotdict({'numMCTSSims': 50, 'cpuct':1.0}) mcts1 = MCTS(g, n1, args1) n1p = lambda x: np.argmax(mcts1.getActionProb(x, temp=0))