[Feature] Parallel collection (#152)
* parallel collection

* parallel collection

* parallel collection

* fixes

* fixes

* fixes

* revert buffer update
matteobettini authored Dec 18, 2024
1 parent 271edee commit e910a83
Showing 14 changed files with 155 additions and 73 deletions.
42 changes: 1 addition & 41 deletions benchmarl/algorithms/common.py
@@ -13,21 +13,13 @@
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
Categorical,
Composite,
LazyTensorStorage,
OneHot,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
from torchrl.envs import (
Compose,
EnvBase,
InitTracker,
TensorDictPrimer,
Transform,
TransformedEnv,
)
from torchrl.envs import Compose, EnvBase, Transform
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -251,38 +243,6 @@ def process_env_fun(
Returns: a function that takes no args and creates an environment
"""
if self.has_rnn:

def model_fun():
env = env_fun()

spec_actor = self.model_config.get_model_state_spec()
spec_actor = Composite(
{
group: Composite(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in self.group_map.items()
}
)

env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor, reset_key="_reset")]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return env

return model_fun

return env_fun

7 changes: 5 additions & 2 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -15,6 +15,9 @@ share_policy_params: True
prefer_continuous_actions: True
# If False collection is done using a collector (under no grad). If True, collection is done with gradients.
collect_with_grad: False
# In case of non-vectorized environments, whether to run collection using multiple processes
# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each
parallel_collection: False

# Discount factor
gamma: 0.9
@@ -51,7 +54,7 @@ max_n_frames: 3_000_000
on_policy_collected_frames_per_batch: 6000
# Number of environments used for collection
# If the environment is vectorized, this will be the number of batched environments.
# Otherwise batching will be simulated and each env will be run sequentially.
# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
on_policy_n_envs_per_worker: 10
# This is the number of times collected_frames_per_batch will be split into minibatches and trained
on_policy_n_minibatch_iters: 45
@@ -63,7 +66,7 @@ on_policy_minibatch_size: 400
off_policy_collected_frames_per_batch: 6000
# Number of environments used for collection
# If the environment is vectorized, this will be the number of batched environments.
# Otherwise batching will be simulated and each env will be run sequentially.
# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
off_policy_n_envs_per_worker: 10
# This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
off_policy_n_optimizer_steps: 1000
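
A minimal sketch of how the new flag can be switched on programmatically, assuming the standard BenchMARL flow where ExperimentConfig is loaded from the YAML above (the specific numbers are illustrative):

from benchmarl.experiment import ExperimentConfig

# Load the defaults defined in base_experiment.yaml.
experiment_config = ExperimentConfig.get_from_yaml()

# Opt in to multi-process collection (only relevant for non-vectorized environments).
experiment_config.parallel_collection = True

# With 10 envs per worker and 6000 frames per batch, collection would use
# 10 worker processes gathering 600 frames each.
experiment_config.off_policy_n_envs_per_worker = 10
experiment_config.off_policy_collected_frames_per_batch = 6000
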
2 changes: 1 addition & 1 deletion benchmarl/environments/common.py
@@ -34,7 +34,7 @@ def _type_check_task_config(
else:
if warn_on_missing_dataclass:
warnings.warn(
"TaskConfig python dataclass not foud, task is being loaded without type checks"
"TaskConfig python dataclass not found, task is being loaded without type checks"
)
return config

9 changes: 5 additions & 4 deletions benchmarl/environments/magent/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Callable, Dict, List, Optional

from torchrl.data import Composite
@@ -31,17 +31,18 @@ def get_env_fun(
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
config = copy.deepcopy(self.config)

return lambda: PettingZooWrapper(
env=self.__get_env(),
env=self.__get_env(config),
return_state=True,
seed=seed,
done_on_any=False,
use_mask=False,
device=device,
)

def __get_env(self) -> EnvBase:
def __get_env(self, config) -> EnvBase:
try:
from magent2.environments import (
adversarial_pursuit_v4,
@@ -66,7 +67,7 @@ def __get_env(self) -> EnvBase:
}
if self.name not in envs:
raise Exception(f"{self.name} is not an environment of MAgent2")
return envs[self.name].parallel_env(**self.config, render_mode="rgb_array")
return envs[self.name].parallel_env(**config, render_mode="rgb_array")

def supports_continuous_actions(self) -> bool:
return False
6 changes: 4 additions & 2 deletions benchmarl/environments/meltingpot/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Callable, Dict, List, Optional

import torch
@@ -84,11 +84,13 @@ def get_env_fun(
) -> Callable[[], EnvBase]:
from torchrl.envs.libs.meltingpot import MeltingpotEnv

config = copy.deepcopy(self.config)

return lambda: MeltingpotEnv(
substrate=self.name.lower(),
categorical_actions=True,
device=device,
**self.config,
**config,
)

def supports_continuous_actions(self) -> bool:
7 changes: 4 additions & 3 deletions benchmarl/environments/pettingzoo/common.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Callable, Dict, List, Optional

from torchrl.data import Composite
@@ -35,17 +36,17 @@ def get_env_fun(
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
config = copy.deepcopy(self.config)
if self.supports_continuous_actions() and self.supports_discrete_actions():
self.config.update({"continuous_actions": continuous_actions})

config.update({"continuous_actions": continuous_actions})
return lambda: PettingZooEnv(
categorical_actions=True,
device=device,
seed=seed,
parallel=True,
return_state=self.has_state(),
render_mode="rgb_array",
**self.config
**config
)

def supports_continuous_actions(self) -> bool:
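
The same copy.deepcopy pattern appears in the MAgent2, MeltingPot, SMACv2 and VMAS wrappers in this commit. A self-contained sketch of the reasoning, with hypothetical names: the returned closure may now run in separate collection processes, so it should capture its own snapshot of the config instead of mutating the task's shared one.

import copy


class _TaskSketch:
    def __init__(self, config: dict):
        self.config = config

    def get_env_fun_old(self, continuous_actions: bool):
        # Old pattern: mutates the shared task config before building the closure,
        # so later calls (and other workers) observe the side effect.
        self.config.update({"continuous_actions": continuous_actions})
        return lambda: dict(self.config)

    def get_env_fun_new(self, continuous_actions: bool):
        # New pattern: the closure captures an independent copy of the config.
        config = copy.deepcopy(self.config)
        config.update({"continuous_actions": continuous_actions})
        return lambda: dict(config)
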
5 changes: 3 additions & 2 deletions benchmarl/environments/smacv2/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Callable, Dict, List, Optional

import torch
@@ -42,8 +42,9 @@ def get_env_fun(
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
config = copy.deepcopy(self.config)
return lambda: SMACv2Env(
categorical_actions=True, seed=seed, device=device, **self.config
categorical_actions=True, seed=seed, device=device, **config
)

def supports_continuous_actions(self) -> bool:
5 changes: 3 additions & 2 deletions benchmarl/environments/vmas/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Callable, Dict, List, Optional

from torchrl.data import Composite
@@ -52,6 +52,7 @@ def get_env_fun(
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
config = copy.deepcopy(self.config)
return lambda: VmasEnv(
scenario=self.name.lower(),
num_envs=num_envs,
@@ -60,7 +61,7 @@ def get_env_fun(
device=device,
categorical_actions=True,
clamp_actions=True,
**self.config,
**config,
)

def supports_continuous_actions(self) -> bool:
42 changes: 29 additions & 13 deletions benchmarl/experiment/experiment.py
@@ -14,13 +14,15 @@
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path

from typing import Any, Dict, List, Optional

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs import SerialEnv, TransformedEnv

from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
@@ -34,7 +36,7 @@
from benchmarl.experiment.logger import Logger
from benchmarl.models import GnnConfig, SequenceModelConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import _read_yaml_config, seed_everything
from benchmarl.utils import _add_rnn_transforms, _read_yaml_config, seed_everything

_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
@@ -58,6 +60,7 @@ class ExperimentConfig:
share_policy_params: bool = MISSING
prefer_continuous_actions: bool = MISSING
collect_with_grad: bool = MISSING
parallel_collection: bool = MISSING

gamma: float = MISSING
lr: float = MISSING
@@ -430,20 +433,10 @@ def _setup_task(self):
transforms_training = transforms_env + [
self.task.get_reward_sum_transform(test_env)
]

transforms_env = Compose(*transforms_env)
transforms_training = Compose(*transforms_training)

if test_env.batch_size == ():
self.env_func = lambda: TransformedEnv(
SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func),
transforms_training.clone(),
)
else:
self.env_func = lambda: TransformedEnv(
env_func(), transforms_training.clone()
)

# Initialize test env
self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
self.config.sampling_device
)
@@ -457,6 +450,29 @@ def _setup_task(self):
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(self.test_env)

# Add rnn transforms here so they do not show in the benchmarl specs
if self.model_config.is_rnn:
self.test_env = _add_rnn_transforms(
lambda: self.test_env, self.group_map, self.model_config
)()
env_func = _add_rnn_transforms(env_func, self.group_map, self.model_config)

# Initialize train env
if self.test_env.batch_size == ():
# If the environment is not vectorized, we simulate vectorization using parallel or serial environments
env_class = (
SerialEnv if not self.config.parallel_collection else ParallelEnv
)
self.env_func = lambda: TransformedEnv(
env_class(self.config.n_envs_per_worker(self.on_policy), env_func),
transforms_training.clone(),
)
else:
# Otherwise it is already vectorized
self.env_func = lambda: TransformedEnv(
env_func(), transforms_training.clone()
)

def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

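
A short sketch of the train-env selection added above, assuming TorchRL's SerialEnv and ParallelEnv (both take a number of workers and a create-env callable, so the swap is a single line); the helper name is hypothetical:

from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv


def _make_train_env(env_func, n_envs, parallel_collection, transforms_training):
    # Non-vectorized envs: simulate a batch of n_envs copies, either in-process
    # (SerialEnv) or with one worker process per copy (ParallelEnv).
    env_class = ParallelEnv if parallel_collection else SerialEnv
    return TransformedEnv(env_class(n_envs, env_func), transforms_training.clone())
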
56 changes: 55 additions & 1 deletion benchmarl/utils.py
@@ -6,10 +6,16 @@

import importlib
import random
from typing import Any, Dict, Union
import typing
from typing import Any, Callable, Dict, List, Union

import torch
import yaml
from torchrl.data import Composite
from torchrl.envs import Compose, EnvBase, InitTracker, TensorDictPrimer, TransformedEnv

if typing.TYPE_CHECKING:
from benchmarl.models import ModelConfig

_has_numpy = importlib.util.find_spec("numpy") is not None

@@ -53,3 +59,51 @@ def seed_everything(seed: int):
import numpy

numpy.random.seed(seed)


def _add_rnn_transforms(
env_fun: Callable[[], EnvBase],
group_map: Dict[str, List[str]],
model_config: "ModelConfig",
) -> Callable[[], EnvBase]:
"""
This function adds RNN specific transforms to the environment
Args:
env_fun (callable): a function that takes no args and creates an environment
group_map (Dict[str,List[str]]): the group_map of the agents
model_config (ModelConfig): the model configuration
Returns: a function that takes no args and creates an environment
"""

def model_fun():
env = env_fun()
spec_actor = model_config.get_model_state_spec()
spec_actor = Composite(
{
group: Composite(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in group_map.items()
}
)

out_env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor, reset_key="_reset")]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return out_env

return model_fun
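
A brief usage sketch, mirroring the call sites in experiment.py (the group_map and model_config values are placeholders):

# Wrap an environment factory so every env it creates carries an InitTracker
# transform and, if the model declares recurrent state specs, a TensorDictPrimer.
rnn_env_fun = _add_rnn_transforms(
    env_fun=env_fun,                                # any callable returning an EnvBase
    group_map={"agents": ["agent_0", "agent_1"]},   # placeholder group map
    model_config=model_config,                      # a ModelConfig with is_rnn == True
)
env = rnn_env_fun()  # a TransformedEnv ready for collection
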