diff --git a/pz_risk/wrappers/__init__.py b/pz_risk/wrappers/__init__.py index e690d9a..7b53e06 100644 --- a/pz_risk/wrappers/__init__.py +++ b/pz_risk/wrappers/__init__.py @@ -1 +1,5 @@ from .assert_invalid_actions import AssertInvalidActionsWrapper +from .vector_observation import VectorObservationWrapper +from .graph_observation import GraphObservationWrapper +from .sparse_reward import SparseRewardWrapper +from .dense_reward import DenseRewardWrapper diff --git a/pz_risk/wrappers/assert_invalid_actions.py b/pz_risk/wrappers/assert_invalid_actions.py index 23c8542..9b6daab 100644 --- a/pz_risk/wrappers/assert_invalid_actions.py +++ b/pz_risk/wrappers/assert_invalid_actions.py @@ -26,19 +26,21 @@ def step(self, action): elif state == GameState.Card: assert 0 <= action <= 1, 'Card Action should be 0 or 1: {}'.format(action) elif state == GameState.Attack: - edges = self.board.player_attack_edges(player) - assert action[0] <= 1, 'Attack Finished should be 0 or 1: {}'.format(action[0]) - if action[0] == 0: - assert action[1] in edges, 'Attack Can not be performed from {} to {}'.format(gn(action[1][0]), gn(action[1][1])) + if not action[0]: + edges = self.board.player_attack_edges(player) + assert action[0] <= 1, 'Attack Finished should be 0 or 1: {}'.format(action[0]) + if action[0] == 0: + assert action[1] in edges, 'Attack Can not be performed from {} to {}'.format(gn(action[1][0]), gn(action[1][1])) elif state == GameState.Move: u = max(0, self.board.g.nodes[self.board.last_attack[1]]['units'] - 3) assert 0 <= action <= u, 'Move out of bound: {} ~ {}'.format(action, u) elif state == GameState.Fortify: - cc = self.board.player_connected_components(player) - c = [c for c in cc if action[1] in c][0] - assert 0 <= action[0] <= 1, 'Skip should be 0 or 1: {}'.format(action[0]) - assert action[2] in c, 'Fortify Can not be performed from {} to {}'.format(gn(action[1]), gn(action[2])) - assert action[3] <= self.board.g.nodes[action[1]]['units'] ,'Fortify Can not be more than source units!' + if not action[0]: + cc = self.board.player_connected_components(player) + c = [c for c in cc if action[1] in c][0] + assert 0 <= action[0] <= 1, 'Skip should be 0 or 1: {}'.format(action[0]) + assert action[2] in c, 'Fortify Can not be performed from {} to {}'.format(gn(action[1]), gn(action[2])) + assert action[3] <= self.board.g.nodes[action[1]]['units'] ,'Fortify Can not be more than source units!' super().step(action)