Skip to content

Commit

Permalink
update assertion wrapper and add wrappers to __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mahi97 committed Oct 26, 2021
1 parent a463664 commit 8b5c5eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
4 changes: 4 additions & 0 deletions pz_risk/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .assert_invalid_actions import AssertInvalidActionsWrapper
from .vector_observation import VectorObservationWrapper
from .graph_observation import GraphObservationWrapper
from .sparse_reward import SparseRewardWrapper
from .dense_reward import DenseRewardWrapper
20 changes: 11 additions & 9 deletions pz_risk/wrappers/assert_invalid_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,21 @@ def step(self, action):
elif state == GameState.Card:
assert 0 <= action <= 1, 'Card Action should be 0 or 1: {}'.format(action)
elif state == GameState.Attack:
edges = self.board.player_attack_edges(player)
assert action[0] <= 1, 'Attack Finished should be 0 or 1: {}'.format(action[0])
if action[0] == 0:
assert action[1] in edges, 'Attack Can not be performed from {} to {}'.format(gn(action[1][0]), gn(action[1][1]))
if not action[0]:
edges = self.board.player_attack_edges(player)
assert action[0] <= 1, 'Attack Finished should be 0 or 1: {}'.format(action[0])
if action[0] == 0:
assert action[1] in edges, 'Attack Can not be performed from {} to {}'.format(gn(action[1][0]), gn(action[1][1]))
elif state == GameState.Move:
u = max(0, self.board.g.nodes[self.board.last_attack[1]]['units'] - 3)
assert 0 <= action <= u, 'Move out of bound: {} ~ {}'.format(action, u)
elif state == GameState.Fortify:
cc = self.board.player_connected_components(player)
c = [c for c in cc if action[1] in c][0]
assert 0 <= action[0] <= 1, 'Skip should be 0 or 1: {}'.format(action[0])
assert action[2] in c, 'Fortify Can not be performed from {} to {}'.format(gn(action[1]), gn(action[2]))
assert action[3] <= self.board.g.nodes[action[1]]['units'] ,'Fortify Can not be more than source units!'
if not action[0]:
cc = self.board.player_connected_components(player)
c = [c for c in cc if action[1] in c][0]
assert 0 <= action[0] <= 1, 'Skip should be 0 or 1: {}'.format(action[0])
assert action[2] in c, 'Fortify Can not be performed from {} to {}'.format(gn(action[1]), gn(action[2]))
assert action[3] <= self.board.g.nodes[action[1]]['units'] ,'Fortify Can not be more than source units!'

super().step(action)

Expand Down

0 comments on commit 8b5c5eb

Please sign in to comment.