forked from NTT123/a0-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgo_web_app.py
132 lines (108 loc) · 3.51 KB
/
go_web_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# go game web server
import pickle
import random
from argparse import ArgumentParser
from collections import defaultdict
import jax
import jax.numpy as jnp
from flask import Flask, jsonify, redirect, request, send_file, url_for
from play import play_one_move
from utils import env_step, import_class, reset_env
# Command-line options: the game class, the agent (network) class, the
# checkpoint file to restore, and the number of simulations per agent move.
parser = ArgumentParser()
parser.add_argument("--game-class", type=str, default="go_game.GoBoard9x9")
parser.add_argument(
    "--agent-class",
    type=str,
    default="resnet_policy.ResnetPolicyValueNet256",
)
parser.add_argument("--ckpt-filename", type=str, default="go_agent_9x9_256.ckpt")
parser.add_argument("--num_simulations_per_move", type=int, default=1024)
args = parser.parse_args()

# Forwarded to play_one_move as `enable_mcts` on every agent move.
enable_mcts = True

# A game instance built once at start-up; its observation shape and action
# count are what the agent constructor needs.
env = import_class(args.game_class)()
agent = import_class(args.agent_class)(
    input_dims=env.observation().shape,
    num_actions=env.num_actions(),
)

# NOTE: unpickling can execute arbitrary code, so --ckpt-filename must only
# ever point at a checkpoint from a trusted source.
with open(args.ckpt_filename, "rb") as ckpt_file:
    saved_state = pickle.load(ckpt_file)["agent"]
agent = agent.load_state_dict(saved_state).eval()

# Game state per game id. Looking up an id that has not been seen yet
# lazily creates a new game instance for it.
all_games = defaultdict(lambda: import_class(args.game_class)())
def _outcome_message(reward, human_moved):
    """Return the banner text for the reward of the move that was just played.

    A reward of +1 is reported as a win for whoever made the move and -1 as a
    loss for them; any other value produces an empty message.

    Args:
        reward: Python scalar reward returned by the step that was just played.
        human_moved: True if the human made that move, False if the agent did.
    """
    human_score = reward if human_moved else -reward
    if human_score == 1:
        return "You won!"
    if human_score == -1:
        return "You lost :-("
    return ""


def _move_response(env, action, msg):
    """Build the JSON-serialisable payload describing the state after a move."""
    return {
        "action": action,
        "terminated": env.is_terminated().item(),
        "current_board": env.board.reshape((-1,)).tolist(),
        "msg": msg,
    }


def human_vs_agent(env, info):
    """Apply the human's move (if any), then let the agent answer.

    Args:
        env: Current game state for this game id.
        info: Request payload. ``info["human_action"]`` is one of:
            ``-1``     -- the agent goes first; the game is reset and no human
                          move is played,
            ``"pass"`` -- mapped to the last action index,
                          ``env.num_actions() - 1``,
            otherwise  -- the action index of the human's move.

    Returns:
        A ``(env, response)`` tuple. ``env`` is the updated game state and
        ``response`` is a dict with keys ``"action"`` (the agent's action, or
        ``-1`` when the human's move already ended the game),
        ``"terminated"``, ``"current_board"`` (the board flattened to a list)
        and ``"msg"`` (text to display, possibly empty).

    Uses the module-level ``agent``, ``enable_mcts`` and ``args``.
    """
    human_action = info["human_action"]
    if human_action == -1:
        # Agent goes first
        env = reset_env(env)
        env.render()
    else:
        if human_action == "pass":
            human_action = env.num_actions() - 1
        env, reward = env_step(env, jnp.array(human_action, dtype=jnp.int32))
        if env.is_terminated().item():
            # The human's move ended the game, so the agent does not reply.
            msg = _outcome_message(reward.item(), human_moved=True)
            return env, _move_response(env, -1, msg)

    rng_key = jax.random.PRNGKey(random.randint(0, 999999))
    # Only the chosen action is needed; the other two results are ignored.
    action, _action_weights, _value = play_one_move(
        agent,
        env,
        rng_key,
        enable_mcts=enable_mcts,
        num_simulations=args.num_simulations_per_move,
        random_action=False,
    )
    env, reward = env_step(env, action)
    msg = _outcome_message(reward.item(), human_moved=False)
    action = action.item()
    # Announce a pass only when there is no win/lose banner to show instead.
    if not msg and action == env.num_actions() - 1:
        msg = "AI PASSED!"
    return env, _move_response(env, action, msg)
# Flask application object; the route handlers below are registered on it.
app = Flask(__name__)
@app.route("/<int:gameid>/move", methods=["POST"])
def move(gameid):
    """Play one human move (and the agent's reply) in game `gameid`.

    The request body must be a JSON object containing a "human_action" field.
    Any other body is rejected with HTTP 400 instead of failing inside
    `human_vs_agent` with an unhandled exception.

    Returns:
        The JSON-encoded response dict produced by `human_vs_agent`, or a JSON
        error message with status 400 for a malformed request.
    """
    info = request.get_json()
    if not isinstance(info, dict) or "human_action" not in info:
        error = {"msg": "request body must be a JSON object with a 'human_action' field"}
        return jsonify(error), 400
    env, res = human_vs_agent(all_games[gameid], info)
    # Store the updated state so the next request continues this game.
    all_games[gameid] = env
    return jsonify(res)
@app.route("/<int:gameid>", methods=["GET"])
def startgame(gameid):
    """Reset the game stored under `gameid` and serve the board page."""
    fresh_env = reset_env(all_games[gameid])
    all_games[gameid] = fresh_env
    return send_file("./index.html")
@app.route("/")
def index():
    """Create a new game under a random id and redirect the browser to it.

    The id is redrawn a bounded number of times if it is already in use, so a
    new visitor does not silently replace somebody else's game. If every
    attempt collides, the last drawn id is used anyway rather than looping
    forever.
    """
    gameid = random.randint(0, 999999)
    for _ in range(100):
        # `in` does not trigger the defaultdict factory, so this check does
        # not create an entry as a side effect.
        if gameid not in all_games:
            break
        gameid = random.randint(0, 999999)
    all_games[gameid] = import_class(args.game_class)()
    return redirect(url_for("startgame", gameid=gameid))
@app.route("/<int:gameid>/reset")
def reset(gameid):
    """Replace the state of game `gameid` with a reset one; reply with `{}`."""
    current_env = all_games[gameid]
    all_games[gameid] = reset_env(current_env)
    empty_reply = {}
    return empty_reply
@app.route("/stone.ogg")
def stone():
    """Serve the file `./stone.ogg`."""
    sound_path = "./stone.ogg"
    return send_file(sound_path)
if __name__ == "__main__":
    # "0.0.0.0" makes the server listen on every network interface.
    bind_host = "0.0.0.0"
    app.run(host=bind_host)