diff --git a/metaworld/__init__.py b/metaworld/__init__.py index a3255ebe0..a0ff4a971 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -571,40 +571,31 @@ def _ml_bench_vector_entry_point( **lamb_kwargs, ) - for name in ALL_V3_ENVIRONMENTS.keys(): - kwargs = {"name": name} - register( - id=f"Meta-World/{name}", - entry_point="metaworld:make_mt_envs", - kwargs=kwargs, - ) - for vector_strategy in ["sync", "async"]: - for split in ["train", "test"]: - register( - id=f"Meta-World/ML1-{split}-{name}-{vector_strategy}", - vector_entry_point=partial( - _ml_bench_vector_entry_point, name, split, vector_strategy - ), - kwargs={}, - ) + for vector_strategy in ["sync", "async"]: + for split in ["train", "test"]: + register( + id=f"Meta-World/ML1-{split}-{vector_strategy}", + vector_entry_point=lambda env_name, seed=None: partial( + _ml_bench_vector_entry_point, env_name, split, vector_strategy + ), + kwargs={}, + ) - for name_hid in ALL_V3_ENVIRONMENTS_GOAL_HIDDEN: - register( - id=f"Meta-World/{name_hid}", - entry_point=lambda seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[name_hid]( # type: ignore - seed=seed - ), - kwargs={}, - ) + register( + id=f"Meta-World/goal_hidden", + entry_point=lambda name_hid, seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[name_hid + '-goal-hidden' if '-goal-hidden' not in env_name else '']( # type: ignore + seed=seed + ), + kwargs={}, + ) - for name_obs in ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE: - register( - id=f"Meta-World/{name_obs}", - entry_point=lambda seed: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[name_obs]( # type: ignore - seed=seed - ), - kwargs={}, - ) + register( + id=f"Meta-World/goal_observable", + entry_point=lambda env_name, seed=None: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[env_name + '-goal-observable' if '-goal-observable' not in env_name else '']( # type: ignore + seed=seed + ), + kwargs={}, + ) for mt_bench in ["MT10", "MT50"]: for vector_strategy in ["sync", "async"]: