From 08a28d38dab17cb4656cd73517bf894575ce6f1a Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Fri, 17 Jan 2025 23:46:45 +0100 Subject: [PATCH] Fix `get_wrapper_attr` / `set_wrapper_attr`. (#1293) Co-authored-by: Mark Towers --- gymnasium/core.py | 19 ++++++++++++------- tests/test_core.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 9dfe63876..aaf9476f7 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -432,18 +432,23 @@ def set_wrapper_attr(self, name: str, value: Any): name: The variable name value: The new variable value """ - sub_env = self.env - attr_set = False + sub_env = self - while attr_set is False and isinstance(sub_env, Wrapper): + # loop through all the wrappers, checking if it has the variable name then setting it + # otherwise stripping the wrapper to check the next. + # end when the core env is reached + while isinstance(sub_env, Wrapper): if hasattr(sub_env, name): setattr(sub_env, name, value) - attr_set = True - else: - sub_env = sub_env.env + return + + sub_env = sub_env.env - if attr_set is False: + # check if the base environment has the wrapper, otherwise, we set it on the top (this) wrapper + if hasattr(sub_env, name): setattr(sub_env, name, value) + else: + setattr(self, name, value) def __str__(self): """Returns the wrapper name and the :attr:`env` representation string.""" diff --git a/tests/test_core.py b/tests/test_core.py index 196b64f73..7e7391aad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -215,6 +215,16 @@ def test_get_set_wrapper_attr(): env.unwrapped._disable_render_order_enforcing assert env.get_wrapper_attr("_disable_render_order_enforcing") is True + # Test with top-most wrapper + env.MY_ATTRIBUTE_1 = True + assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is True + env.set_wrapper_attr("MY_ATTRIBUTE_1", False) + assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is False + + # Test with non-existing attribute + env.set_wrapper_attr("MY_ATTRIBUTE_2", True) + assert getattr(env, "MY_ATTRIBUTE_2") is True + class TestRandomSeeding: @staticmethod