From db2e346366856a306c6ad2b88cb58475e5dbba31 Mon Sep 17 00:00:00 2001 From: Mohammad Mahdi Rahimi Date: Tue, 26 Oct 2021 17:56:54 +0900 Subject: [PATCH] base update --- pz_risk/core/board.py | 53 ++++++++++++++-- pz_risk/risk_env.py | 87 ++++++++++++++------------ pz_risk/train.py | 142 ++++++++++++++++++++++++++++++++++++++++++ pz_risk/utils.py | 39 ++++++++++++ test.py | 53 ++++++++++++++++ 5 files changed, 329 insertions(+), 45 deletions(-) create mode 100644 pz_risk/train.py create mode 100644 test.py diff --git a/pz_risk/core/board.py b/pz_risk/core/board.py index 34067ed..4d23c89 100644 --- a/pz_risk/core/board.py +++ b/pz_risk/core/board.py @@ -22,6 +22,39 @@ def __init__(self, graph: nx.Graph, info, pos=None): self.last_attack = (None, None) self.state = GameState.StartTurn self.info = info + self.n_grps = info['num_of_groups'] + self.n_cards = self.g.number_of_nodes() + self.info['num_of_wild'] + + def valid_actions(self, player): + """ + :player + return: + bool: isDeterministic? + list: action list + """ + acts = [] + if self.state == GameState.Reinforce: + acts = self.player_nodes(player) + elif self.state == GameState.Card: + acts = [0, 1] if len(self.players[player].cards) < 5 else [1] + elif self.state == GameState.Attack: + edges = self.player_attack_edges(player) + acts = [(1, (None, None))] + # assert action[0] <= 1, 'Attack Finished should be 0 or 1: {}'.format(action[0]) + acts += [(0, e) for e in edges] + elif self.state == GameState.Move: + u = max(0, self.g.nodes[self.last_attack[1]]['units'] - 3) + acts = [i for i in range(u+1)] + elif self.state == GameState.Fortify: + cc = self.player_connected_components(player) + acts = [(1, None, None, None)] + for c in cc: + for a in c: + for b in c: + if a != b and self.g.nodes[a]['units'] > 1: + acts.append((0, a, b, self.g.nodes[a]['units'] - 1)) + + return self.state != GameState.Attack, acts def can_fortify(self, player): cc = self.player_connected_components(player) @@ -94,7 +127,7 @@ def player_units(self, player): return sum([n[1]['units'] for n in self.g.nodes(data=True) if n[1]['player'] == player]) def player_group_reward(self, player): - group = {gid + 1: True for gid in range(self.info['num_of_group'])} + group = {gid + 1: True for gid in range(self.n_grps)} for n in self.g.nodes(data=True): if n[1]['player'] != player: group[n[1]['gid']] = False @@ -128,8 +161,10 @@ def player_attack_edges(self, player): ee.append((e[1], e[0])) return ee - def reset(self, n_agent, n_unit_per_agent, n_cell_per_agent): + def reset(self, n_agent): n_cells = self.g.number_of_nodes() + n_cell_per_agent = n_cells // n_agent + n_unit_per_agent = self.info['num_of_unit'] assert n_cell_per_agent * n_agent == n_cells remaining_cells = [i for i in self.g.nodes()] @@ -178,10 +213,11 @@ def apply_best_match(self, player): if cnt == 3: break else: - match_type = CardType.Artillery if ct[CardType.Artillery] >= 3 - cnt\ - else CardType.Cavalry if ct[CardType.Cavalry] >= 3 - cnt\ + match_type = CardType.Artillery if ct[CardType.Artillery] >= 3 - cnt \ + else CardType.Cavalry if ct[CardType.Cavalry] >= 3 - cnt \ else CardType.Infantry - used += [self.players[player].cards[match_type].pop(-1) for _ in range(3 - cnt) if len(self.players[player].cards[match_type])] + used += [self.players[player].cards[match_type].pop(-1) for _ in range(3 - cnt) if + len(self.players[player].cards[match_type])] self.players[player].placement += CARD_FIX_SCORE[match_type] for c in used: c.owner = -1 @@ -245,7 +281,8 @@ def step(self, agent, actions, left=None): 
self.g.nodes[actions[1]]['units'] -= int(actions[3]) self.g.nodes[actions[2]]['units'] += int(actions[3]) - self.next_state(agent, self.state, attack_succeed, attack_finished, len(self.player_nodes(agent)) == len(self.g.nodes())) + self.next_state(agent, self.state, attack_succeed, attack_finished, + len(self.player_nodes(agent)) == len(self.g.nodes())) if self.state == GameState.StartTurn and self.players[agent].deserve_card: self.give_card(agent) @@ -261,5 +298,9 @@ def register_map(name, filepath): BOARDS[name] = Board(g, m['info']) + print(os.getcwd()) register_map('world', './maps/world.json') +register_map('4node', './maps/4node.json') +register_map('6node', './maps/6node.json') +register_map('8node', './maps/8node.json') diff --git a/pz_risk/risk_env.py b/pz_risk/risk_env.py index a7bd924..26cfebd 100644 --- a/pz_risk/risk_env.py +++ b/pz_risk/risk_env.py @@ -1,7 +1,7 @@ import math import random -from gym.spaces import Discrete, MultiDiscrete, Dict +from gym.spaces import Discrete, MultiDiscrete, Dict, MultiBinary, Box, Tuple from pettingzoo import AECEnv from pettingzoo.utils import agent_selector from pettingzoo.utils import wrappers @@ -16,6 +16,7 @@ from core.gamestate import GameState from loguru import logger +from copy import deepcopy from utils import * from agents.sampling import SAMPLING @@ -33,7 +34,7 @@ ] -def env(): +def env(n_agent=6, board_name='world'): """ The env function wraps the environment in 3 wrappers by default. These wrappers contain logic that is common to many pettingzoo environments. @@ -41,7 +42,7 @@ def env(): to provide sane error messages. You can find full documentation for these methods elsewhere in the developer documentation. """ - env = RiskEnv() + env = RiskEnv(n_agent, board_name) env = wrappers.CaptureStdoutWrapper(env) env = risk_wrappers.AssertInvalidActionsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -58,7 +59,7 @@ class RiskEnv(AECEnv): metadata = {'render.modes': ['human'], "name": "rps_v2"} def __init__(self, n_agent=6, board_name='world'): - ''' + """ - n_agent: Number of Agent - board: ['test', 'world', 'world2'] The init method takes in environment arguments and @@ -68,27 +69,29 @@ def __init__(self, n_agent=6, board_name='world'): - observation_spaces These attributes should not be changed after initialization. 
- ''' + """ super().__init__() self.board = BOARDS[board_name] - n_nodes = self.board.g.number_of_nodes() - n_edges = self.board.g.number_of_edges() + self.n_nodes = self.board.g.number_of_nodes() + self.n_edges = self.board.g.number_of_edges() + self.n_grps = self.board.n_grps + self.n_cards = self.board.n_cards + self.n_agents = n_agent self.possible_agents = [r for r in range(n_agent)] self.agent_name_mapping = dict(zip(self.possible_agents, list(range(len(self.possible_agents))))) # Gym spaces are defined and documented here: https://gym.openai.com/docs/#spaces - self.action_spaces = {agent: {GameState.Reinforce: Discrete(n_nodes), - GameState.Attack: MultiDiscrete([2, n_edges]), # +1 for Skip - GameState.Fortify: MultiDiscrete([2, n_nodes, n_nodes, 100]), # Last dim for Skip - GameState.StartTurn: Discrete(1), - GameState.EndTurn: Discrete(1), - GameState.Card: Discrete(2), - GameState.Move: Discrete(100) - } for agent in self.possible_agents} - self.observation_spaces = {agent: Discrete(MAX_UNIT) for agent in self.possible_agents} # placement - self.observation_spaces['board'] = Dict({}) - self.observation_spaces['cards'] = MultiDiscrete([MAX_CARD for _ in range(n_agent)]) - self.observation_spaces['my_cards'] = Discrete(2) + self.action_spaces = {GameState.Reinforce: Discrete(self.n_nodes), + GameState.Attack: MultiDiscrete([2, self.n_edges]), # +1 for Skip + GameState.Fortify: MultiDiscrete([2, self.n_nodes, self.n_nodes, 100]), + # Last dim for Skip + # GameState.StartTurn: Discrete(1), + # GameState.EndTurn: Discrete(1), + GameState.Card: Discrete(2), + GameState.Move: Discrete(100) + } + # self.action_spaces = Box(0, 1000, shape=[self.n_nodes + self.n_edges + self.n_nodes+self.n_nodes + 100 + 1+1+1]) + self.observation_spaces = None # Core.Board() self.agents = [] self.rewards = {} @@ -136,10 +139,10 @@ def render_info(self, mode="human"): plt.pause(0.001) def render(self, mode="human"): - ''' + """ Renders the environment. In human mode, it can print to terminal, open up a graphical window, or open up some other display that a human can see and understand. - ''' + """ plt.figure(0) plt.clf() @@ -169,18 +172,13 @@ def render(self, mode="human"): print('Wait for it') def observe(self, agent): - ''' + """ Observe should return the observation of the specified agent. This function should return a sane observation (though not necessarily the most up to date possible) at any time after reset() is called. 
- ''' - # observation of one agent is the previous state of the other + """ - return {'board': self.board, - 'my_card': self.board.players[agent].cards, - 'placement': self.board.players[agent].placement, - 'game_state': self.board.state, - 'cards': [len(p.cards) for p in self.board.players]} + return self.board def close(self): """ @@ -208,8 +206,8 @@ def reset(self): self.rewards = {agent: 0 for agent in self.agents} self._cumulative_rewards = {agent: 0 for agent in self.agents} self.dones = {agent: False for agent in self.agents} - self.infos = {agent: {} for agent in self.agents} - self.board.reset(len(self.agents), 20, 7) + self.infos = {agent: {'nodes': self.n_nodes, 'agents': self.n_agents} for agent in self.agents} + self.board.reset(len(self.agents)) self.num_turns = 0 self.num_moves = 1 ''' @@ -218,12 +216,20 @@ def reset(self): self._agent_selector = agent_selector(self.agents) self.agent_selection = self._agent_selector.next() + self.land_hist = {a: [] for a in self.possible_agents} + self.unit_hist = {a: [] for a in self.possible_agents} + self.place_hist = {a: [] for a in self.possible_agents} + def reward(self, agent): return 0.0 def done(self, agent): - return False + return len(self.board.player_nodes(agent)) == 0 + # def get_action(self, action): + # # [self.n_nodes + self.n_edges + self.n_nodes + self.n_nodes + 100 + 1 + 1 + 1] + # if self.board.state == GameState.Reinforce: + # action def step(self, action): """ step(action) takes in an action for the current agent (specified by @@ -243,9 +249,11 @@ def step(self, action): agent = self.agent_selection state = self.board.state - logger.info('Player: {}, State: {}, Actions: {}'.format(agent, state, action)) + # logger.info('Player: {}, State: {}, Actions: {}'.format(agent, state, action)) self._cumulative_rewards[agent] = 0 + # if len(action) == self.action_spaces.shape[0]: + # action = self.get_action(action) self.board.step(agent, action) @@ -274,15 +282,16 @@ def step(self, action): if self.board.state == GameState.EndTurn: self.dones = {agent: True for agent in self.agents} # Adds .rewards to ._cumulative_rewards - self._accumulate_rewards() + # self._accumulate_rewards() if __name__ == '__main__': - e = env() + e = env(2, 'world') e.reset() # e.render() winner = -1 - for agent in e.agent_iter(): + for i, agent in enumerate(e.agent_iter()): + print(i) obs, rew, done, info = e.last() if done: continue @@ -294,8 +303,8 @@ def step(self, action): if all(e.dones.values()): winner = agent break - # e.render() - # e.render() - # plt.show() + e.render() + e.render() + plt.show() logger.info('Done in {} Turns and {} Moves. 
Winner is Player {}' .format(e.unwrapped.num_turns, e.unwrapped.num_moves, winner)) diff --git a/pz_risk/train.py b/pz_risk/train.py new file mode 100644 index 0000000..7903982 --- /dev/null +++ b/pz_risk/train.py @@ -0,0 +1,142 @@ +import os +import time +from collections import deque + +import numpy as np +import torch + +import training.utils as utils +from training.envs import make_vec_envs +from training.model import Policy +from training.storage import RolloutStorage +from training.arguments import get_args +from training.evaluation import evaluate +from training.ppo import PPO + +from risk_env import env +from pz_risk.wrappers import GraphObservationWrapper + + +def main(): + args = get_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + log_dir = os.path.expanduser(args.log_dir) + eval_log_dir = log_dir + "_eval" + utils.cleanup_log_dir(log_dir) + utils.cleanup_log_dir(eval_log_dir) + + torch.set_num_threads(1) + device = torch.device("cuda:0" if args.cuda else "cpu") + + e = env() + e = GraphObservationWrapper(e) + e.reset() + actor_critic = Policy( + e.observation_spaces, + e.action_spaces) + actor_critic.to(device) + + agent = PPO( + actor_critic, + args.clip_param, + args.ppo_epoch, + args.num_mini_batch, + args.value_loss_coef, + args.entropy_coef, + lr=args.lr, + eps=args.eps, + max_grad_norm=args.max_grad_norm) + + rollouts = RolloutStorage(args.num_steps, args.num_processes, + e.observation_spaces['feat'].shape, e.action_spaces, + e.observation_spaces['task_id'].shape) + + obs, _, _, _ = e.last() + rollouts.obs[0].copy_(obs) + rollouts.to(device) + + episode_rewards = deque(maxlen=10) + + start = time.time() + num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes + for j in range(num_updates): + + if args.use_linear_lr_decay: + # decrease learning rate linearly + utils.update_linear_schedule(agent.optimizer, j, num_updates, args.lr) + + for step in range(args.num_steps): + # Sample actions + with torch.no_grad(): + value, action, action_log_prob = actor_critic.act(rollouts.obs[step], rollouts.task_id[step], + rollouts.masks[step]) + + # Observe reward and next obs + obs, reward, done, infos = e.step(action) + + for info in infos: + if 'episode' in info.keys(): + episode_rewards.append(info['episode']['r']) + + # If done then clean the history of observations. 
+ masks = torch.FloatTensor( + [[0.0] if done_ else [1.0] for done_ in done]) + bad_masks = torch.FloatTensor( + [[0.0] if 'bad_transition' in info.keys() else [1.0] + for info in infos]) + rollouts.insert(obs['feat'], obs['task_id'], action, + action_log_prob, value, reward, masks, bad_masks) + + with torch.no_grad(): + next_value = actor_critic.get_value( + rollouts.obs[-1], rollouts.task_id[-1], + rollouts.masks[-1]).detach() + + rollouts.compute_returns(next_value, args.gamma, args.use_proper_time_limits) + + value_loss, action_loss, dist_entropy = agent.update(rollouts) + + rollouts.after_update() + + # save for every interval-th episode or for the last epoch + if (j % args.save_interval == 0 or j == num_updates - 1) and args.save_dir != "": + save_path = os.path.join(args.save_dir, args.algo) + try: + os.makedirs(save_path) + except OSError: + pass + + torch.save([ + actor_critic, + getattr(utils.get_vec_normalize(e), 'obs_rms', None) + ], os.path.join(save_path, args.env_name + ".pt")) + + if j % args.log_interval == 0 and len(episode_rewards) > 1: + total_num_steps = (j + 1) * args.num_processes * args.num_steps + end = time.time() + print( + "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}," + " min/max reward {:.1f}/{:.1f}\n" + .format(j, total_num_steps, + int(total_num_steps / (end - start)), + len(episode_rewards), np.mean(episode_rewards), + np.median(episode_rewards), np.min(episode_rewards), + np.max(episode_rewards), dist_entropy, value_loss, + action_loss)) + + if (args.eval_interval is not None and len(episode_rewards) > 1 + and j % args.eval_interval == 0): + obs_rms = utils.get_vec_normalize(e).obs_rms + evaluate(actor_critic, obs_rms, args.env_name, args.seed, + args.num_processes, eval_log_dir, device) + + +if __name__ == "__main__": + main() diff --git a/pz_risk/utils.py b/pz_risk/utils.py index 55e5ff4..f949b8c 100644 --- a/pz_risk/utils.py +++ b/pz_risk/utils.py @@ -5,6 +5,9 @@ import numpy as np from collections import Iterable +from copy import deepcopy +import networkx as nx + rng = np.random.default_rng() sided_die = 6 attack_max = 3 @@ -23,6 +26,10 @@ def single_roll(attack: int, defend: int) -> (int, int): return attack_loss, defend_loss +def to_one_hot(num, maximum): + return [1 if i == num else 0 for i in range(maximum)] + + def flatten(lis): for item in lis: if isinstance(item, Iterable) and not isinstance(item, str): @@ -30,3 +37,35 @@ def flatten(lis): yield x else: yield item + + +def get_feat_adj_from_board(board, player, n_agents, n_grps): + feats = [] + for node in board.g.nodes(data=True): + feats.append([ + 0, + node[1]['units'], + node[1]['gid'], + board.info['group_reward'][str(node[1]['gid'])], + -1, -1 # Player + ]) + # feat_player = [] + for p in range(n_agents): + feats.append([ + -1, 0, -1, 0, + 1 if board.can_card(p) else 0, + 1 if p == player else 0 + ]) + temp_g = nx.Graph(deepcopy(board.g)) + temp_g.add_nodes_from(['p' + str(i) for i in range(n_agents)]) + edges = [[node[0], 'p' + str(node[1]['player'])] for node in temp_g.nodes(data=True) if 'player' in node[1]] + + # edges = [] + # for i in range(n_agents): + # edges += [[node[0], 'p' + str(i)] for node in temp_g.nodes(data=True) if 'player' in node[1]] + temp_g.add_edges_from(edges) + + feats = [list(flatten([to_one_hot(feat[0], 2), feat[1], to_one_hot(feat[2], n_grps), feat[3], + to_one_hot(feat[4], 2), to_one_hot(feat[5], 2)])) for feat in feats] + adj = nx.adjacency_matrix(temp_g).todense() + np.eye(len(temp_g.nodes())) + return 
feats, adj
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..23bc74e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,53 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+win_rate = np.array([
+    [[0.417, 0.583, 0.],  # 1 vs 1
+     [0.255, 0.745, 0.]],  # 1 vs 2
+    [[0.579, 0.421, 0.],  # 2 vs 1
+     [0.228, 0.324, 0.448]],  # 2 vs 2
+    [[0.660, 0.340, 0.],  # 3 vs 1
+     [0.371, 0.336, 0.293]]  # 3 vs 2
+])
+
+d3 = {}
+
+def get_chance(attack_unit, defense_unit, left):
+    global win_rate, d3
+    i_a = min(attack_unit - 1, 2)
+    i_d = min(defense_unit - 1, 1)
+    if (attack_unit, defense_unit, left) in d3:
+        c = d3[(attack_unit, defense_unit, left)]
+        return c
+
+    c = 0.0
+    if left < -defense_unit or left > attack_unit:
+        c = 0.0
+    elif defense_unit < 0 or attack_unit < 0:
+        c = 0.0
+    elif attack_unit == 0:
+        if left == -defense_unit:
+            c = 1.0
+        else:
+            c = 0.0
+    elif defense_unit == 0:
+        if left == attack_unit:
+            c = 1.0
+        else:
+            c = 0.0
+    else:
+        c = win_rate[i_a, i_d, 0] * get_chance(attack_unit, defense_unit - min(min(i_a, i_d) + 1, 2), left) + \
+            win_rate[i_a, i_d, 1] * get_chance(attack_unit - 1, defense_unit - min(i_a, 1), left) + \
+            win_rate[i_a, i_d, 2] * get_chance(attack_unit - 2, defense_unit, left)
+    d3[(attack_unit, defense_unit, left)] = c
+    return c
+
+k, j = 8, 8
+
+
+b = range(-j, k)
+a = [get_chance(k, j, i) for i in b]
+
+
+plt.plot(b, a)
+plt.show()
\ No newline at end of file
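
The recursion in test.py walks the standard Risk single-roll odds stored in win_rate (rows: attacker rolls 1-3 dice, columns: defender rolls 1-2; entries: P(defender loses the contested units), P(one loss each, or the attacker's single die losing), P(attacker loses two)) and memoizes results in d3, so get_chance(a, d, left) is the probability that a battle fought to exhaustion ends with `left` attacker units surviving (a negative `left` means -left defender units survive). A quick sanity check under that reading: the mass over every reachable outcome should sum to one, and range(-j, k + 1) also covers left == k, the case where the attacker wins without losing a unit, which the plotted range(-j, k) stops short of:

    # Sanity check for get_chance (run after the definitions above): the
    # final-outcome probabilities should sum to ~1.0 once left == k is included.
    total = sum(get_chance(k, j, i) for i in range(-j, k + 1))
    print('total probability mass: {:.6f}'.format(total))  # expect ~1.0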
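
Board.valid_actions(player) in core/board.py enumerates the legal moves for the current GameState and returns them with a flag that is False only for Attack, the one phase whose outcome is stochastic. The encodings roughly parallel the per-state action_spaces declared in risk_env.py: a node id for Reinforce, 0 or 1 for Card (trading is forced once a player holds five cards), (skip, (src, dst)) pairs for Attack, a unit count for Move, and (skip, src, dst, units) tuples for Fortify. A minimal random policy on top of it might look like the sketch below (it assumes only the Board API shown above):

    import random

    def random_valid_action(board, player):
        # valid_actions() returns (is_deterministic, action_list); any entry
        # of the list is legal for the board's current state.
        _, actions = board.valid_actions(player)
        return random.choice(actions)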
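
get_feat_adj_from_board() in utils.py builds the graph inputs for a GNN policy: one raw feature row per territory ([type, units, group id, group reward, -1, -1]) and one per player ([-1, 0, -1, 0, can_card, is_current]), each flattened through to_one_hot() into a vector of width n_grps + 8, plus an adjacency matrix over the map graph extended with a 'p<i>' node per player, an ownership edge from every territory to its owner, and self-loops on the diagonal. A minimal shape check under those assumptions (run from the pz_risk directory so the relative ./maps paths in register_map resolve; BOARDS importable from core.board is an assumption):

    from core.board import BOARDS                     # assumption: BOARDS lives in core.board
    from utils import get_feat_adj_from_board

    board = BOARDS['world']
    board.reset(6)                                    # reset() asserts territories divide evenly among players
    n_nodes, n_agents = board.g.number_of_nodes(), 6
    feats, adj = get_feat_adj_from_board(board, 0, n_agents, board.n_grps)
    assert len(feats) == n_nodes + n_agents           # territory rows + player rows
    assert all(len(f) == board.n_grps + 8 for f in feats)  # 2 + 1 + n_grps + 1 + 2 + 2
    assert adj.shape == (n_nodes + n_agents, n_nodes + n_agents)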
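
train.py wires this into a PPO rollout loop (Policy, PPO and RolloutStorage come from the training/ package) and expects the env to be wrapped in a GraphObservationWrapper that is not part of this patch: after wrapping, observations should be dicts with at least 'feat', 'adj' and 'task_id' entries, with matching shapes published in observation_spaces. A hypothetical sketch of such a wrapper, built on get_feat_adj_from_board (the class body, the task-id encoding and the BaseWrapper base are assumptions, not the project's actual implementation):

    import numpy as np
    from pettingzoo.utils.wrappers import BaseWrapper

    from core.gamestate import GameState
    from utils import get_feat_adj_from_board, to_one_hot

    class GraphObservationWrapper(BaseWrapper):
        """Hypothetical sketch: turn the raw Board observation into GNN tensors."""

        def observe(self, agent):
            board = self.env.observe(agent)
            feat, adj = get_feat_adj_from_board(board, agent,
                                                self.unwrapped.n_agents, board.n_grps)
            return {
                'feat': np.asarray(feat, dtype=np.float32),
                'adj': np.asarray(adj, dtype=np.float32),
                # one-hot of the current phase, so the policy can switch heads
                'task_id': np.asarray(
                    to_one_hot(list(GameState).index(board.state), len(GameState)),
                    dtype=np.float32),
            }

A real implementation would also have to expose the corresponding 'feat', 'adj' and 'task_id' observation spaces, since train.py reads their shapes when sizing the Policy and the RolloutStorage buffers.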