add model agent and update others
mahi97 committed Oct 26, 2021
1 parent 5513a2e commit 375b37b
Showing 6 changed files with 107 additions and 42 deletions.
6 changes: 4 additions & 2 deletions pz_risk/agents/__init__.py
@@ -1,6 +1,8 @@
from .base import BaseAgent
from .greedy import GreedyAgent
from .random import RandomAgent

from .model import ModelAgent
from .value import warm_up
warm_up()
print('Start Warm up...')
warm_up()
print('Warm up Done.')
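
Importing the agents package exposes the four agent classes and runs warm_up() once at import time. A minimal sketch of a consumer import, assuming pz_risk/ is on the path; RandomAgent taking a player_id like the other agents is an assumption:

from agents import GreedyAgent, RandomAgent, ModelAgent  # first import triggers the module-level warm_up()

players = [GreedyAgent(0), RandomAgent(1)]  # ModelAgent additionally needs its checkpoint on disk (see model.py)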
28 changes: 18 additions & 10 deletions pz_risk/agents/greedy.py
@@ -1,12 +1,14 @@
import numpy as np

from core.board import Board
from agents.base import BaseAgent
from core.gamestate import GameState
from agents.sampling import SAMPLING
from agents.value import manual_advantage, manual_q

from loguru import logger


class GreedyAgent(BaseAgent):
def __init__(self, player_id):
super(GreedyAgent, self).__init__()
@@ -15,15 +17,21 @@ def __init__(self, player_id):
def reset(self):
pass

def act(self, state):
if state['game_state'] == GameState.Attack:
attack_edge = state['board'].player_attack_edges(self.player_id)
base = manual_advantage(state, self.player_id, (1, None)) # Attack Finished
v = [manual_advantage(state, self.player_id, (False, ae)) for ae in attack_edge]
edge = attack_edge[np.argmax(v)]
logger.info('Attack values:{}, base: {}'.format(v, base))
return (1, (None, None)) if base > max(v) else (0, edge)
else:
return SAMPLING[state['game_state']](state['board'], self.player_id)
def act(self, state: Board):
# if state.state == GameState.Attack:
# attack_edge = state.player_attack_edges(self.player_id)
# base = manual_advantage(state, self.player_id, (1, None)) # Attack Finished
# v = [manual_advantage(state, self.player_id, (False, ae)) for ae in attack_edge]
# edge = attack_edge[np.argmax(v)]
# # logger.info('Attack values:{}, base: {}'.format(v, base))
# return (1, (None, None)) if base > max(v) else (0, edge)
# else:
# Use Model to Gather Future State per Valid Actions
action_scores = []
deterministic, valid_actions = state.valid_actions(self.player_id)
for valid_action in valid_actions:
action_scores.append(manual_advantage(state, self.player_id, valid_action))
action = valid_actions[np.argmax(action_scores)]
return action
# Find the action with highest advantage
# Execute the action
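
The new act() is a one-step lookahead: every valid action is scored with manual_advantage and the argmax is returned. A minimal sketch of driving one greedy decision, assuming a board object that exposes the valid_actions() and step() methods used above (how the board is constructed is outside this diff):

from agents.greedy import GreedyAgent

def play_greedy_turn(board, player_id=0):
    # Score each currently valid action by its manual advantage and play the best one.
    agent = GreedyAgent(player_id)
    action = agent.act(board)       # argmax of manual_advantage over board.valid_actions(player_id)
    board.step(player_id, action)   # Board.step(player, action), as used in value.py
    return action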
53 changes: 53 additions & 0 deletions pz_risk/agents/model.py
@@ -0,0 +1,53 @@
import os
import numpy as np

from core.board import Board
from agents.base import BaseAgent

import torch
from copy import deepcopy
from agents.value import get_future, get_attack_dist
from utils import get_feat_adj_from_board
from training.dvn import DVNAgent


class ModelAgent(BaseAgent):
def __init__(self, player_id, device='cuda:0'):
super(ModelAgent, self).__init__()
self.player_id = player_id
self.device = device
feat_size = 14 # e.observation_spaces['feat'].shape[0]
hidden_size = 20

self.critic = DVNAgent(feat_size, hidden_size)
save_path = './trained_models4/'
load = 8
self.critic.load_state_dict(torch.load(os.path.join(save_path, str(load) + ".pt")))
self.critic.eval()

# feat = torch.tensor(state['feat'], dtype=torch.float32, device=device).reshape(-1, 48, feat_size)
# adj = torch.tensor(state['adj'], dtype=torch.float32, device=device).reshape(-1, 48, 48)

def reset(self):
pass

def act(self, state: Board):
action_scores = []
deterministic, valid_actions = state.valid_actions(self.player_id)
for valid_action in valid_actions:
sim = deepcopy(state)
if deterministic:
sim.step(self.player_id, valid_action)
else:
dist = get_attack_dist(state, valid_action)
if len(dist): # TODO: Change to sampling
left = get_future(dist, mode='most')
sim.step(self.player_id, valid_action, left)
else:
sim.step(self.player_id, valid_action)
sim_feat, sim_adj = get_feat_adj_from_board(sim, self.player_id, 6, 6)
sim_feat = torch.tensor(sim_feat, dtype=torch.float32, device=self.device).reshape(-1, 48, 14)
sim_adj = torch.tensor(sim_adj, dtype=torch.float32, device=self.device).reshape(-1, 48, 48)
action_scores.append(self.critic(sim_feat, sim_adj).detach().cpu().numpy()[:, 42 + self.player_id])
action = valid_actions[np.argmax(action_scores)]
return action
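
ModelAgent mirrors the greedy one-step lookahead but scores each simulated board with the trained DVN critic instead of the manual heuristic; the 48-node tensors appear to correspond to 42 territory nodes plus 6 player nodes, with the acting player's value read from output node 42 + player_id (an inference from the reshapes above, not something the code states). A minimal usage sketch, assuming the hard-coded checkpoint ./trained_models4/8.pt exists and a CUDA device is available (the constructor defaults to 'cuda:0'):

from agents.model import ModelAgent

agent = ModelAgent(player_id=0)  # loads ./trained_models4/8.pt onto cuda:0

def play_model_turn(board):
    # Each valid action is simulated (using the most likely dice outcome for attacks),
    # featurized with get_feat_adj_from_board, scored by the critic, and the argmax is played.
    action = agent.act(board)
    board.step(agent.player_id, action)
    return action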
2 changes: 1 addition & 1 deletion pz_risk/agents/random.py
@@ -11,4 +11,4 @@ def reset(self):
pass

def act(self, state):
return SAMPLING[state['game_state']](state['board'], self.player_id)
return SAMPLING[state.state](state, self.player_id)
10 changes: 0 additions & 10 deletions pz_risk/agents/sampling.py
@@ -6,17 +6,7 @@


def sample_reinforce(board, player):
# num_units = board.players[player].placement
nodes = board.player_nodes(player)
# r2 = {n: 0 for n in nodes}
# for _ in range(num_units):
# i = np.random.choice(nodes)
# r2[i] += 1
# branches = math.comb(num_units + len(nodes) - 1, num_units)
# index = np.random.randint(branches)
# r = sorted([(index // num_units ** i) % num_units for i in range(len(nodes))])
# r.append(num_units)
# r2 = {n: r[i + 1] - r[i] for i, n in zip(range(len(nodes)), random.sample(nodes, len(nodes)))}
return np.random.choice(nodes) # [(0 if n not in nodes else r2[n]) for n in board.g.nodes()]


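sample_reinforce draws one of the player's own territories uniformly at random; a toy stand-in for the call, with hypothetical node ids in place of board.player_nodes(player):

import numpy as np

nodes = [3, 7, 12]               # what board.player_nodes(player) might return
node = np.random.choice(nodes)   # territory chosen to receive the reinforcement
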
50 changes: 31 additions & 19 deletions pz_risk/agents/value.py
@@ -7,12 +7,12 @@

# From https://web.stanford.edu/~guertin/risk.notes.html
win_rate = np.array([
[[0.417, 0.583, 0.], # 1 vs 1
[0.255, 0.745, 0.]], # 1 vs 2
[[0.579, 0.421, 0.], # 2 vs 1
[0.228, 0.324, 0.448]], # 2 vs 2
[[0.660, 0.340, 0.], # 3 vs 1
[0.371, 0.336, 0.293]] # 3 vs 2
[[0.417, 0.583, 0.], # 1 vs 1
[0.255, 0.745, 0.]], # 1 vs 2
[[0.579, 0.421, 0.], # 2 vs 1
[0.228, 0.324, 0.448]], # 2 vs 2
[[0.660, 0.340, 0.], # 3 vs 1
[0.371, 0.336, 0.293]] # 3 vs 2
])

d3 = {}
@@ -69,9 +69,9 @@ def get_future(dist, mode='safe', risk=0.0):
max_index = np.argmax([d[1] for d in dist])
return dist[max_index][0]
elif mode == 'two':
left = [d[1] for d in dist if d[1] < 0]
left = [d[1] for d in dist if d[1] <= 0]
right = [d[1] for d in dist if d[1] > 0]
left_score = dist[np.argmax(left)][0] * sum(left)
left_score = dist[np.argmax(left)][0] * (sum(left) + risk)
right_score = dist[np.argmax(right)][0] * sum(right)
return left_score + right_score
elif mode == 'all':
@@ -80,11 +80,14 @@


def manual_value(board, player):
num_lands = len(board.player_nodes(player)) * 5
num_units = board.player_units(player)
group_reward = board.player_group_reward(player) * 10
my_lands = len(board.player_nodes(player))
opp_lands = max([len(board.player_nodes(p)) for p in range(6) if p != player])
my_units = board.player_units(player)
opp_units = max([board.player_units(p) for p in range(6) if p != player])
my_group_reward = board.player_group_reward(player)
opp_group_reward = max([board.player_group_reward(p) for p in range(6) if p != player])
num_cards = sum([len(c) for c in board.players[player].cards.values()])
return num_lands + num_units + group_reward + num_cards
return my_units + my_lands * 3 + my_group_reward * 10 + num_cards # (my_lands - opp_lands) + (my_group_reward - opp_group_reward)*2


def man_q_deterministic(board, player, action):
@@ -93,21 +96,30 @@ def man_q_deterministic(board, player, action):
return manual_value(sim, player)


def man_q_attack(board, player, action):
def get_attack_dist(board, action):
attack_finished = action[0]
sim = deepcopy(board)
dist = []
if not attack_finished:
src = action[1][0]
trg = action[1][1]
src_unit = board.g.nodes[src]['units']
trg_unit = board.g.nodes[trg]['units']
dist = [(i, get_chance(src_unit, trg_unit, i)) for i in range(-trg_unit, src_unit+1)]
left = get_future(dist, mode='most')
dist = [(i, get_chance(src_unit, trg_unit, i)) for i in range(-trg_unit, src_unit + 1)]
return dist


def man_q_attack(board, player, action):
sim = deepcopy(board)
dist = get_attack_dist(board, action)
if len(dist) > 0:
left = get_future(dist, mode='all')
sim.step(player, action, left)
else:
sim.step(player, action)
return manual_value(sim, player)


def manual_q(state, player, action):
def manual_q(board, player, action):
Q = {
GameState.Reinforce: man_q_deterministic,
GameState.Attack: man_q_attack,
@@ -117,8 +129,8 @@ def manual_q(state, player, action):
GameState.EndTurn: lambda b, p: None
}

return Q[state['game_state']](state['board'], player, action)
return Q[board.state](board, player, action)


def manual_advantage(state, player, action):
return manual_q(state, player, action) - manual_value(state['board'], player)
return manual_q(state, player, action) - manual_value(state, player)
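
Taken together: get_attack_dist enumerates (units-left, probability) outcomes for an attack, get_future collapses that distribution to one representative outcome, man_q_attack scores the simulated result with manual_value, and manual_advantage subtracts the current value from that Q estimate. A toy check of the two pieces fully visible above, with hand-made numbers (the distribution and counts are illustrative only, not drawn from win_rate):

import numpy as np

# get_future(dist, mode='most') picks the outcome with the highest probability;
# each pair is (integer outcome, probability), as built by get_attack_dist.
dist = [(-2, 0.10), (-1, 0.15), (0, 0.20), (1, 0.25), (2, 0.30)]
most_likely = dist[np.argmax([p for _, p in dist])][0]   # -> 2

# manual_value is a linear heuristic over the player's position.
my_units, my_lands, my_group_reward, num_cards = 20, 9, 2, 3
value = my_units + my_lands * 3 + my_group_reward * 10 + num_cards   # 20 + 27 + 20 + 3 = 70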
