forked from NTT123/a0-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenv.py
61 lines (43 loc) · 1.58 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Environment base class.
"""
from typing import Any, Tuple, TypeVar
import chex
import pax
# Type variable standing for the concrete environment class; `step` annotates
# `self` with it so the environment it returns has the same type as the receiver.
E = TypeVar("E")
class Enviroment(pax.Module):
    """Base template that concrete environments derive from.

    Methods that raise ``NotImplementedError`` have no default behaviour and
    must be provided by a subclass. The remaining methods ship with a
    default: ``reset``, ``observation`` and ``canonical_observation`` do
    nothing and return ``None``, ``parse_action`` converts its argument with
    ``int``, and ``symmetries`` returns only the untransformed input pair.
    """

    def __init__(self):
        super().__init__()

    def step(self: E, action: chex.Array) -> Tuple[E, chex.Array]:
        """Advance the environment by one step using ``action``.

        Returns:
            A pair of an environment of the same type as ``self`` and an
            array (presumably the reward -- confirm against a concrete
            environment).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reset(self):
        """Reset the environment.

        The base implementation is a no-op and returns ``None``.
        """
        return None

    def is_terminated(self) -> chex.Array:
        """Tell whether the environment is terminated.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def observation(self) -> Any:
        """Return the observation from the environment.

        The base implementation returns ``None``.
        """
        return None

    def canonical_observation(self) -> Any:
        """Return the canonical observation.

        The base implementation returns ``None``.
        """
        return None

    def num_actions(self) -> int:
        """Return the size of the action space.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def invalid_actions(self) -> chex.Array:
        """Return a boolean array flagging the invalid actions.

        Returns:
            invalid_action: the i-th element is true if action `i` is
                invalid. [num_actions].

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def max_num_steps(self) -> int:
        """Return the maximum number of steps until the game is terminated.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def parse_action(self, action_str: str) -> int:
        """Convert the textual action ``action_str`` into its integer form.

        Raises:
            ValueError: if ``int`` cannot parse ``action_str``.
        """
        action_index = int(action_str)
        return action_index

    def symmetries(self, state, action_weights):
        """Return the symmetric variants of ``(state, action_weights)``.

        The default symmetry group is the identity, so the result is a
        one-element list holding the inputs unchanged. This method is used
        for data augmentation in training.
        """
        identity_pair = (state, action_weights)
        return [identity_pair]