Support for different algorithms for different groups #156

Open
karthiks1701 opened this issue Jan 6, 2025 · 7 comments · May be fixed by #159

Comments

@karthiks1701

karthiks1701 commented Jan 6, 2025

Thanks for this amazing repository, this is a great addition to the MARL research community.

I would like to know whether there is support for using different algorithms for different groups in the same experiment (e.g., half the agents use IPPO and the other half use ISAC in the same environment). By grouping I am referring to the shared abstraction introduced in BenchMARL/TorchRL. By extension, this would also require support for using different components (models, replay buffers, etc.) per group. If there is no current support, what steps would one need to take to add this functionality?

@matteobettini
Collaborator

matteobettini commented Jan 7, 2025

TL;DR

  • For algorithms: create a custom algorithm that routes its abstract functions to different algorithms based on the group arg
  • For models: create a custom model config that routes its get_model() function to different model configs based on the group arg

Full reply

Hello! Thanks for opening this!

While there is currently no support in the interface for doing this, we have structured the library such that we can allow it easily. In fact, I already use something like this for my own runs.

Essentially, different groups already have separate replay buffers, losses, optimizers, etc.

If you look at experiment._setup_algorithm() you can see how the algorithm is set up:

def _setup_algorithm(self):
    self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
    self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
    self.env_func = self.algorithm.process_env_fun(self.env_func)
    self.replay_buffers = {
        group: self.algorithm.get_replay_buffer(
            group=group,
            transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
        )
        for group in self.group_map.keys()
    }
    self.losses = {
        group: self.algorithm.get_loss_and_updater(group)[0]
        for group in self.group_map.keys()
    }
    self.target_updaters = {
        group: self.algorithm.get_loss_and_updater(group)[1]
        for group in self.group_map.keys()
    }
    self.optimizers = {
        group: {
            loss_name: torch.optim.Adam(
                params, lr=self.config.lr, eps=self.config.adam_eps
            )
            for loss_name, params in self.algorithm.get_parameters(group).items()
        }
        for group in self.group_map.keys()
    }

You can see that every function of the algorithm class already takes group as a parameter.

I'll first discuss one possibility for having different algorithms for different groups and then different models

Different algorithms for each group

The easiest way to do this is to create a new custom algorithm (let's say IppoIsac) and then implement the abstract functions, routing them to the desired algorithm based on the given group. For example:

class IppoIsac(Algorithm):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.ippo = Ippo(...)
        self.isac = Isac(...)

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        # Route the call to the algorithm assigned to this group
        if group == "attackers":
            return self.ippo._get_loss(group, policy_for_loss, continuous)
        elif group == "defenders":
            return self.isac._get_loss(group, policy_for_loss, continuous)
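
(The other abstract functions of the Algorithm class would be routed in the same way, and the nested Ippo and Isac instances would of course need to be constructed with whatever arguments their constructors expect; the snippet above only sketches the routing pattern.)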

Different models for each group

For models it is even simpler: you can create a custom model config that reimplements get_model() and routes it to different models based on the group parameter.
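
For illustration, here is a minimal sketch of such a routing config. It assumes that get_model() receives the group name through its agent_group argument (as in ModelConfig) and that the remaining arguments can be forwarded as keywords; the class name PerGroupModelConfig and the group keys are made up, and the other abstract methods of ModelConfig are omitted:

from typing import Dict

from benchmarl.models.common import ModelConfig


class PerGroupModelConfig(ModelConfig):
    # Hypothetical config that delegates to a different ModelConfig per group.
    # The other abstract methods of ModelConfig would still need implementing.

    def __init__(self, group_configs: Dict[str, ModelConfig]):
        # e.g. {"attackers": MlpConfig.get_from_yaml(), "defenders": GnnConfig.get_from_yaml()}
        self.group_configs = group_configs

    def get_model(self, agent_group: str, **kwargs):
        # Look up the config registered for this group and delegate to it,
        # forwarding all remaining arguments untouched.
        return self.group_configs[agent_group].get_model(
            agent_group=agent_group, **kwargs
        )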

Note that the choices of using different algorithms and models for different groups are completely decoupled and can be made independently of each other, following the BenchMARL philosophy (you could use the same algo for all groups with different models, or vice versa).

We could provide this

Eventually, what I could do is provide custom classes called EnsembleAlgorithm and EnsembleModel that are configured by taking as input a dictionary mapping groups to AlgorithmConfigs or ModelConfigs.

But I tend to think that this would be a bit beyond the scope of BenchMARL, as it would be really hard to configure through Hydra and it would require users to know what groups there are (i.e., to know which env is being used).

I'll think about it, but in the meantime let me know if the proposed solution works for you.

@karthiks1701
Author

Thanks for the quick reply; this solution should work for my use case (I have yet to implement and test it, but at first glance it looks like the solution I was hoping for).

The EnsembleAlgorithm and EnsembleModel functionality could make BenchMARL stand out from other repositories and be extremely useful for heterogeneous MARL research, although I understand that prior knowledge of environment/groups might be necessary. One naive workaround might be to randomly assign groups to the required algorithms/models clusters, in case the user doesn't specify the assignment.

@matteobettini
Collaborator

The EnsembleAlgorithm and EnsembleModel functionality could make BenchMARL stand out from other repositories and be extremely useful for heterogeneous MARL research, although I understand that prior knowledge of environment/groups might be necessary. One naive workaround might be to randomly assign groups to the required algorithms/models clusters, in case the user doesn't specify the assignment.

I think what I can do is implement them and provide them to the user. To configure them you can just pass the map of group names to component configs.

Since they are hard to configure via Hydra, I can in the first instance expose them only via the Python interface and write a section in the docs on how to use them for interested users.

@matteobettini
Collaborator

I have implemented the ensemble model and will soon do the algorithm in #159.

As you can see, it is just a few lines because I already wanted to allow this in the lib.

@karthiks1701
Author

Thank you! I will take a look soon.

@matteobettini
Collaborator

matteobettini commented Jan 7, 2025

I have also pushed EnsembleAlgorithm.

There is a point I overlooked regarding this. Since the training loop is currently shared by all algos, it is not currently possible to mix on_policy and off_policy algos in the same ensemble (so the IPPO + ISAC combination from the original question would fall under this constraint, since PPO is on-policy and SAC is off-policy). This is not impossible to overcome, but it would require a major restructuring of the lib, so for now I am keeping this constraint.

However, it is now possible to run stuff like this:

from benchmarl.algorithms import EnsembleAlgorithmConfig, IppoConfig, MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import DeepsetsConfig, EnsembleModelConfig, GnnConfig, MlpConfig

if __name__ == "__main__":

    # Loads from "benchmarl/conf/experiment/base_experiment.yaml"
    experiment_config = ExperimentConfig.get_from_yaml()

    # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
    task = VmasTask.SIMPLE_TAG.get_from_yaml()

    algorithm_config = EnsembleAlgorithmConfig(
        {"agent": MappoConfig.get_from_yaml(), "adversary": IppoConfig.get_from_yaml()}
    )

    model_config = EnsembleModelConfig(
        {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()}
    )
    critic_model_config = EnsembleModelConfig(
        {
            "agent": DeepsetsConfig.get_from_yaml(),
            "adversary": MlpConfig.get_from_yaml(),
        }
    )

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()

@karthiks1701
Copy link
Author

Thanks, I think this should be good for my use case right now.
