ray_mlp.py
from typing import Any, Dict, List, Tuple

import gymnasium as gym
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from torch import nn

from popgym.baselines.ray_models.base_model import BaseModel


class MLP(BaseModel):
    """A good old MLP that has no memory whatsoever.

    Useful to see if your memory model is actually using its memory."""

    MODEL_CONFIG: Dict[str, Any] = {
        # See BaseModel for the available embedding types
        "embedding": "sine",
    }

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # Two-layer feedforward core, applied independently at each timestep
        self.core = nn.Sequential(
            nn.Linear(self.cfg["preprocessor_output_size"], self.cfg["hidden_size"]),
            nn.LeakyReLU(),
            nn.Linear(self.cfg["hidden_size"], self.cfg["hidden_size"]),
            nn.LeakyReLU(),
        )

    def initial_state(self) -> List[TensorType]:
        # No recurrent state: the model is memoryless by design
        return []

    def forward_memory(
        self,
        z: TensorType,
        state: List[TensorType],
        t_starts: TensorType,
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        z = self.core(z)
        # State is expected to be a list; return it empty since nothing carries over
        return z, []
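Because forward_memory applies the core independently at each timestep and carries no state, its shape contract can be checked with the two-layer core alone. A minimal sketch, assuming preprocessor_output_size = hidden_size = 128 and the batch/time sizes 4 and 10 (all arbitrary choices for illustration, not popgym defaults):

import torch
from torch import nn

# Stand-in for self.core under the assumed sizes
core = nn.Sequential(
    nn.Linear(128, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 128),
    nn.LeakyReLU(),
)

z = torch.randn(4, 10, 128)  # (batch, time, features); nn.Linear maps the last dim
out = core(z)
assert out.shape == (4, 10, 128)  # per-timestep output, no state carried across time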

class BasicMLP(MLP):
    """The same MLP with the embedding disabled."""

    MODEL_CONFIG = {
        "embedding": None,
    }
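To run either model under RLlib, register it with the model catalog and point the trainer config at it. A hedged sketch: the lookup name "mlp" is arbitrary, and the custom_model_config keys below simply mirror the self.cfg["..."] accesses in __init__; the real schema and defaults live in BaseModel.

from ray.rllib.models import ModelCatalog

# "mlp" is a lookup name chosen for this sketch, not one popgym mandates
ModelCatalog.register_custom_model("mlp", MLP)

config = {
    "model": {
        "custom_model": "mlp",
        # Assumed keys, mirroring the self.cfg lookups in __init__;
        # BaseModel defines the accepted keys and their defaults
        "custom_model_config": {
            "preprocessor_output_size": 128,
            "hidden_size": 128,
        },
    },
}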