From 375b37be74425e3a180e622d1395ca70452d5ffa Mon Sep 17 00:00:00 2001
From: Mohammad Mahdi Rahimi
Date: Tue, 26 Oct 2021 17:55:30 +0900
Subject: [PATCH] add model agent and update others

---
 pz_risk/agents/__init__.py |  6 +++--
 pz_risk/agents/greedy.py   | 28 +++++++++++++-------
 pz_risk/agents/model.py    | 53 ++++++++++++++++++++++++++++++++++++++
 pz_risk/agents/random.py   |  2 +-
 pz_risk/agents/sampling.py | 10 -------
 pz_risk/agents/value.py    | 50 +++++++++++++++++++++--------------
 6 files changed, 107 insertions(+), 42 deletions(-)
 create mode 100644 pz_risk/agents/model.py

diff --git a/pz_risk/agents/__init__.py b/pz_risk/agents/__init__.py
index cd25304..38a8a00 100644
--- a/pz_risk/agents/__init__.py
+++ b/pz_risk/agents/__init__.py
@@ -1,6 +1,8 @@
 from .base import BaseAgent
 from .greedy import GreedyAgent
 from .random import RandomAgent
-
+from .model import ModelAgent
 from .value import warm_up
-warm_up()
\ No newline at end of file
+print('Start Warm up...')
+warm_up()
+print('Warm up Done.')
diff --git a/pz_risk/agents/greedy.py b/pz_risk/agents/greedy.py
index 9c93cf3..6620574 100644
--- a/pz_risk/agents/greedy.py
+++ b/pz_risk/agents/greedy.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from core.board import Board
 from agents.base import BaseAgent
 from core.gamestate import GameState
 from agents.sampling import SAMPLING
@@ -7,6 +8,7 @@
 from loguru import logger
 
 
+
 class GreedyAgent(BaseAgent):
     def __init__(self, player_id):
         super(GreedyAgent, self).__init__()
@@ -15,15 +17,21 @@ def __init__(self, player_id):
     def reset(self):
         pass
 
-    def act(self, state):
-        if state['game_state'] == GameState.Attack:
-            attack_edge = state['board'].player_attack_edges(self.player_id)
-            base = manual_advantage(state, self.player_id, (1, None))  # Attack Finished
-            v = [manual_advantage(state, self.player_id, (False, ae)) for ae in attack_edge]
-            edge = attack_edge[np.argmax(v)]
-            logger.info('Attack values:{}, base: {}'.format(v, base))
-            return (1, (None, None)) if base > max(v) else (0, edge)
-        else:
-            return SAMPLING[state['game_state']](state['board'], self.player_id)
+    def act(self, state: Board):
+        # if state.state == GameState.Attack:
+        #     attack_edge = state.player_attack_edges(self.player_id)
+        #     base = manual_advantage(state, self.player_id, (1, None))  # Attack Finished
+        #     v = [manual_advantage(state, self.player_id, (False, ae)) for ae in attack_edge]
+        #     edge = attack_edge[np.argmax(v)]
+        #     # logger.info('Attack values:{}, base: {}'.format(v, base))
+        #     return (1, (None, None)) if base > max(v) else (0, edge)
+        # else:
+        # Use the model to gather the future state per valid action
+        action_scores = []
+        deterministic, valid_actions = state.valid_actions(self.player_id)
+        for valid_action in valid_actions:
+            action_scores.append(manual_advantage(state, self.player_id, valid_action))
+        action = valid_actions[np.argmax(action_scores)]
+        return action
 # Find the action with highest advantage
 # Execute the action
diff --git a/pz_risk/agents/model.py b/pz_risk/agents/model.py
new file mode 100644
index 0000000..b71e018
--- /dev/null
+++ b/pz_risk/agents/model.py
@@ -0,0 +1,53 @@
+import os
+import numpy as np
+
+from core.board import Board
+from agents.base import BaseAgent
+
+import torch
+from copy import deepcopy
+from agents.value import get_future, get_attack_dist
+from utils import get_feat_adj_from_board
+from training.dvn import DVNAgent
+
+
+class ModelAgent(BaseAgent):
+    def __init__(self, player_id, device='cuda:0'):
+        super(ModelAgent, self).__init__()
+        self.player_id = player_id
+        self.device = device
+        feat_size = 14  # e.observation_spaces['feat'].shape[0]
+        hidden_size = 20
+
+        self.critic = DVNAgent(feat_size, hidden_size)
+        save_path = './trained_models4/'
+        load = 8
+        self.critic.load_state_dict(torch.load(os.path.join(save_path, str(load) + ".pt")))
+        self.critic.eval()
+
+        # feat = torch.tensor(state['feat'], dtype=torch.float32, device=device).reshape(-1, 48, feat_size)
+        # adj = torch.tensor(state['adj'], dtype=torch.float32, device=device).reshape(-1, 48, 48)
+
+    def reset(self):
+        pass
+
+    def act(self, state: Board):
+        action_scores = []
+        deterministic, valid_actions = state.valid_actions(self.player_id)
+        for valid_action in valid_actions:
+            sim = deepcopy(state)
+            if deterministic:
+                sim.step(self.player_id, valid_action)
+            else:
+                dist = get_attack_dist(state, valid_action)
+                if len(dist):  # TODO: Change to sampling
+                    left = get_future(dist, mode='most')
+                    sim.step(self.player_id, valid_action, left)
+                else:
+                    sim.step(self.player_id, valid_action)
+            sim_feat, sim_adj = get_feat_adj_from_board(sim, self.player_id, 6, 6)
+            sim_feat = torch.tensor(sim_feat, dtype=torch.float32, device=self.device).reshape(-1, 48, 14)
+            sim_adj = torch.tensor(sim_adj, dtype=torch.float32, device=self.device).reshape(-1, 48, 48)
+            action_scores.append(self.critic(sim_feat, sim_adj).detach().cpu().numpy()[:, 42 + self.player_id])
+        action = valid_actions[np.argmax(action_scores)]
+        return action
diff --git a/pz_risk/agents/random.py b/pz_risk/agents/random.py
index 8cff87f..f254d0b 100644
--- a/pz_risk/agents/random.py
+++ b/pz_risk/agents/random.py
@@ -11,4 +11,4 @@ def reset(self):
         pass
 
     def act(self, state):
-        return SAMPLING[state['game_state']](state['board'], self.player_id)
+        return SAMPLING[state.state](state, self.player_id)
diff --git a/pz_risk/agents/sampling.py b/pz_risk/agents/sampling.py
index 06491e1..5932ec5 100644
--- a/pz_risk/agents/sampling.py
+++ b/pz_risk/agents/sampling.py
@@ -6,17 +6,7 @@
 
 
 def sample_reinforce(board, player):
-    # num_units = board.players[player].placement
     nodes = board.player_nodes(player)
-    # r2 = {n: 0 for n in nodes}
-    # for _ in range(num_units):
-    #     i = np.random.choice(nodes)
-    #     r2[i] += 1
-    # branches = math.comb(num_units + len(nodes) - 1, num_units)
-    # index = np.random.randint(branches)
-    # r = sorted([(index // num_units ** i) % num_units for i in range(len(nodes))])
-    # r.append(num_units)
-    # r2 = {n: r[i + 1] - r[i] for i, n in zip(range(len(nodes)), random.sample(nodes, len(nodes)))}
     return np.random.choice(nodes)  # [(0 if n not in nodes else r2[n]) for n in board.g.nodes()]
 
 
diff --git a/pz_risk/agents/value.py b/pz_risk/agents/value.py
index 4162529..eb347b9 100644
--- a/pz_risk/agents/value.py
+++ b/pz_risk/agents/value.py
@@ -7,12 +7,12 @@
 
 # From https://web.stanford.edu/~guertin/risk.notes.html
 win_rate = np.array([
-    [[0.417, 0.583, 0.], # 1 vs 1
-     [0.255, 0.745, 0.]], # 1 vs 2
-    [[0.579, 0.421, 0.], # 2 vs 1
-     [0.228, 0.324, 0.448]], # 2 vs 2
-    [[0.660, 0.340, 0.], # 3 vs 1
-     [0.371, 0.336, 0.293]] # 3 vs 2
+    [[0.417, 0.583, 0.],      # 1 vs 1
+     [0.255, 0.745, 0.]],     # 1 vs 2
+    [[0.579, 0.421, 0.],      # 2 vs 1
+     [0.228, 0.324, 0.448]],  # 2 vs 2
+    [[0.660, 0.340, 0.],      # 3 vs 1
+     [0.371, 0.336, 0.293]]   # 3 vs 2
 ])
 
 d3 = {}
@@ -69,9 +69,9 @@ def get_future(dist, mode='safe', risk=0.0):
         max_index = np.argmax([d[1] for d in dist])
         return dist[max_index][0]
     elif mode == 'two':
-        left = [d[1] for d in dist if d[1] < 0]
+        left = [d[1] for d in dist if d[1] <= 0]
         right = [d[1] for d in dist if d[1] > 0]
-        left_score = dist[np.argmax(left)][0] * sum(left)
+        left_score = dist[np.argmax(left)][0] * (sum(left) + risk)
         right_score = dist[np.argmax(right)][0] * sum(right)
         return left_score + right_score
     elif mode == 'all':
@@ -80,11 +80,14 @@
 
 
 def manual_value(board, player):
-    num_lands = len(board.player_nodes(player)) * 5
-    num_units = board.player_units(player)
-    group_reward = board.player_group_reward(player) * 10
+    my_lands = len(board.player_nodes(player))
+    opp_lands = max([len(board.player_nodes(p)) for p in range(6) if p != player])
+    my_units = board.player_units(player)
+    opp_units = max([board.player_units(p) for p in range(6) if p != player])
+    my_group_reward = board.player_group_reward(player)
+    opp_group_reward = max([board.player_group_reward(p) for p in range(6) if p != player])
     num_cards = sum([len(c) for c in board.players[player].cards.values()])
-    return num_lands + num_units + group_reward + num_cards
+    return my_units + my_lands * 3 + my_group_reward * 10 + num_cards  # (my_lands - opp_lands) + (my_group_reward - opp_group_reward)*2
 
 
 def man_q_deterministic(board, player, action):
@@ -93,21 +96,30 @@ def man_q_deterministic(board, player, action):
     return manual_value(sim, player)
 
 
-def man_q_attack(board, player, action):
+def get_attack_dist(board, action):
     attack_finished = action[0]
-    sim = deepcopy(board)
+    dist = []
     if not attack_finished:
         src = action[1][0]
         trg = action[1][1]
         src_unit = board.g.nodes[src]['units']
         trg_unit = board.g.nodes[trg]['units']
-        dist = [(i, get_chance(src_unit, trg_unit, i)) for i in range(-trg_unit, src_unit+1)]
-        left = get_future(dist, mode='most')
+        dist = [(i, get_chance(src_unit, trg_unit, i)) for i in range(-trg_unit, src_unit + 1)]
+    return dist
+
+
+def man_q_attack(board, player, action):
+    sim = deepcopy(board)
+    dist = get_attack_dist(board, action)
+    if len(dist) > 0:
+        left = get_future(dist, mode='all')
         sim.step(player, action, left)
+    else:
+        sim.step(player, action)
     return manual_value(sim, player)
 
 
-def manual_q(state, player, action):
+def manual_q(board, player, action):
     Q = {
         GameState.Reinforce: man_q_deterministic,
         GameState.Attack: man_q_attack,
@@ -117,8 +129,8 @@
         GameState.EndTurn: lambda b, p: None
     }
 
-    return Q[state['game_state']](state['board'], player, action)
+    return Q[board.state](board, player, action)
 
 
 def manual_advantage(state, player, action):
-    return manual_q(state, player, action) - manual_value(state['board'], player)
+    return manual_q(state, player, action) - manual_value(state, player)
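
Usage sketch (not part of the patch): a minimal, hypothetical driver showing how the new ModelAgent and the updated GreedyAgent are meant to be called after this change. The make_board() helper is an assumption and is not provided by this patch; the agent constructors, act(board), and board.step() follow the interfaces used in the diffs above, and ModelAgent additionally expects the checkpoint ./trained_models4/8.pt to be present.

# Hypothetical usage sketch -- make_board() is an assumed helper, not part of this patch.
from agents import GreedyAgent, ModelAgent  # importing the package runs warm_up()

board = make_board()  # assumed: returns a core.board.Board for a fresh game
players = [GreedyAgent(0), ModelAgent(1, device='cuda:0')]

for agent in players:
    # Each agent scores the valid actions for its player (heuristically for
    # GreedyAgent, with the DVN critic for ModelAgent) and returns the best one,
    # which is then applied to the board.
    action = agent.act(board)
    board.step(agent.player_id, action)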