-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenv_util.py
173 lines (151 loc) · 7.33 KB
/
env_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
from typing import Any, Callable, Dict, Optional, Type, Union
import gym
#from stable_baselines3.common.atari_wrappers import AtariWrapper
from atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from language.nl_wrapper import Translearner
from language.task_wrapper import TaskWrapper
from language.tasks import *
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
    """
    Find the first layer of a ``gym.Wrapper`` chain that is a ``wrapper_class``.

    The search starts at ``env`` and follows each wrapper's ``.env`` attribute
    inward. It stops as soon as it reaches an object that is not a
    ``gym.Wrapper``.

    :param env: Outermost (possibly wrapped) environment to start from
    :param wrapper_class: Wrapper class to look for
    :return: The outermost layer that is an instance of ``wrapper_class``,
        or None if no ``gym.Wrapper`` layer in the chain matches
    """
    layer = env
    # Descend while we are still on a wrapper that is not the one we want.
    while isinstance(layer, gym.Wrapper) and not isinstance(layer, wrapper_class):
        layer = layer.env
    # The loop ended either on a matching wrapper or on a non-wrapper object.
    return layer if isinstance(layer, gym.Wrapper) else None
def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
    """
    Check whether an environment has been wrapped with a given wrapper at any depth.

    :param env: Environment instance to check (not an environment class)
    :param wrapper_class: Wrapper class to look for
    :return: True if some ``gym.Wrapper`` layer of ``env`` is an instance of
        ``wrapper_class``, False otherwise.
    """
    return unwrap_wrapper(env, wrapper_class) is not None
def make_vec_env(
    env_id: Union[str, Callable[..., gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Build a vectorized environment in which every sub-env is wrapped in a ``Monitor``.

    By default it uses a ``DummyVecEnv``, which is usually faster than a
    ``SubprocVecEnv``.

    Each sub-env gets a rank ``start_index + i`` for ``i`` in ``range(n_envs)``.
    When ``seed`` is given, both the env and its action space are seeded with
    ``seed + rank``.

    :param env_id: either the env ID (passed to ``gym.make``), the env class,
        or a callable returning an env
    :param n_envs: number of environments to run in parallel
    :param seed: initial seed for the random number generator
    :param start_index: rank of the first environment
    :param monitor_dir: Folder in which ``Monitor`` files are saved, one per rank
        (``<monitor_dir>/<rank>``). The folder is created if needed. If None,
        no file is written, but every env is still wrapped in a ``Monitor``.
    :param wrapper_class: Additional wrapper (or single-argument wrapping
        function) applied to each env. Note: it is applied *after* the
        ``Monitor`` wrapper. In some cases (e.g. with the TimeLimit wrapper)
        this can lead to undesired behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/894
    :param env_kwargs: Keyword arguments passed to the env constructor
    :param vec_env_cls: ``VecEnv`` class to instantiate; ``DummyVecEnv`` if None
    :param vec_env_kwargs: Keyword arguments passed to the ``VecEnv`` constructor
    :param monitor_kwargs: Keyword arguments passed to the ``Monitor`` constructor
    :param wrapper_kwargs: Keyword arguments passed to ``wrapper_class``
    :return: The wrapped, vectorized environment
    """
    env_kwargs = env_kwargs if env_kwargs is not None else {}
    vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else {}
    monitor_kwargs = monitor_kwargs if monitor_kwargs is not None else {}
    wrapper_kwargs = wrapper_kwargs if wrapper_kwargs is not None else {}

    def build_env_fn(rank: int) -> Callable[[], gym.Env]:
        """Return a zero-argument factory creating the env for ``rank``."""

        def _factory() -> gym.Env:
            env = gym.make(env_id, **env_kwargs) if isinstance(env_id, str) else env_id(**env_kwargs)

            if seed is not None:
                env.seed(seed + rank)
                env.action_space.seed(seed + rank)

            # Always add a Monitor for training statistics; it only writes a
            # file when a monitor directory was requested.
            if monitor_dir is None:
                monitor_path = None
            else:
                monitor_path = os.path.join(monitor_dir, str(rank))
                os.makedirs(monitor_dir, exist_ok=True)
            env = Monitor(env, filename=monitor_path, **monitor_kwargs)

            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_kwargs)
            return env

        return _factory

    chosen_cls = DummyVecEnv if vec_env_cls is None else vec_env_cls
    env_fns = [build_env_fn(rank) for rank in range(start_index, start_index + n_envs)]
    return chosen_cls(env_fns, **vec_env_kwargs)
def make_atari_env(
    env_id: Union[str, Callable[..., gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    parser_args=None,
    save_path=None,
) -> VecEnv:
    """
    Create a wrapped, monitored VecEnv for Atari with a language/task layer on top.

    This is a wrapper around ``make_vec_env``. Each sub-env is wrapped, after
    its ``Monitor``, in this order:
    ``AtariWrapper`` -> ``Translearner`` -> ``TaskWrapper``.
    Then a fresh task instance, selected by ``parser_args.task``, is assigned
    to it via ``assign_task``.

    :param env_id: either the env ID, the env class or a callable returning an env
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_kwargs: Optional keyword arguments to pass to the ``AtariWrapper``
    :param env_kwargs: Optional keyword arguments to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :param parser_args: Argument namespace, required. It is forwarded to
        ``Translearner`` as ``args``. Its ``task`` attribute selects the task
        id (0-6) from the table below.
    :param save_path: Forwarded unchanged to ``TaskWrapper`` as ``save_path``.
    :return: The wrapped environment
    :raises ValueError: if ``parser_args`` is None or ``parser_args.task`` is
        not a known task id
    """
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    # Task id -> task class (classes imported from ``language.tasks``).
    task_dict = {
        0: DownLadderJumpRight,
        1: ClimbDownRightLadder,
        2: JumpSkullReachLadder,
        3: JumpSkullGetKey,
        4: ClimbLadderGetKey,
        5: ClimbDownGoRightClimbUp,
        6: JumpMiddleClimbReachLeftDoor,
    }

    # Validate up front. Otherwise these errors only surface inside each env
    # factory (possibly in a SubprocVecEnv worker) as an opaque
    # AttributeError / KeyError.
    if parser_args is None:
        raise ValueError(
            "make_atari_env requires `parser_args`; its `task` attribute selects the task"
        )
    if parser_args.task not in task_dict:
        raise ValueError(
            f"Unknown task id {parser_args.task!r}; expected one of {sorted(task_dict)}"
        )
    task_cls = task_dict[parser_args.task]

    def atari_wrapper(env: gym.Env) -> gym.Env:
        env = AtariWrapper(env, **wrapper_kwargs)
        env = Translearner(env, args=parser_args)
        env = TaskWrapper(env, save_path=save_path)
        # Build a separate task instance around each sub-env's own wrapper stack.
        task = task_cls(env)
        env.assign_task(task)
        return env

    return make_vec_env(
        env_id,
        n_envs=n_envs,
        seed=seed,
        start_index=start_index,
        monitor_dir=monitor_dir,
        wrapper_class=atari_wrapper,
        env_kwargs=env_kwargs,
        vec_env_cls=vec_env_cls,
        vec_env_kwargs=vec_env_kwargs,
        monitor_kwargs=monitor_kwargs,
    )