Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower committed Mar 13, 2024
1 parent f24c8d7 commit 88c51f2
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions test/pettingzoo_api_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import numpy as np
import pytest
from pettingzoo.butterfly import knights_archers_zombies_v10, pistonball_v6, cooperative_pong_v5
from pettingzoo.sisl import pursuit_v4
from pettingzoo.butterfly import (
cooperative_pong_v5,
knights_archers_zombies_v10,
pistonball_v6,
)
from pettingzoo.classic import connect_four_v3
from pettingzoo.mpe import simple_push_v3, simple_world_comm_v3, simple_spread_v3
from pettingzoo.mpe import simple_push_v3, simple_spread_v3, simple_world_comm_v3
from pettingzoo.sisl import pursuit_v4
from pettingzoo.test import api_test, parallel_api_test, seed_test
from pettingzoo.utils.all_modules import (
atari_environments,
butterfly_environments,
classic_environments,
mpe_environments,
sisl_environments,
)

import supersuit
from supersuit import (
Expand All @@ -17,19 +28,23 @@
from supersuit.utils.convert_box import convert_box


from pettingzoo.utils.all_modules import atari_environments, butterfly_environments, classic_environments, mpe_environments, sisl_environments
atari = list(atari_environments.values())
butterfly = list(butterfly_environments.values())
classic = list(classic_environments.values())
mpe = list(mpe_environments.values())
sisl = list(sisl_environments.values())
all = atari + butterfly + classic + mpe + sisl

BUTTERFLY_MPE_CLASSIC = [knights_archers_zombies_v10, simple_push_v3, connect_four_v3, simple_spread_v3]
BUTTERFLY_MPE_CLASSIC = [
knights_archers_zombies_v10,
simple_push_v3,
connect_four_v3,
simple_spread_v3,
]
BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3, simple_spread_v3]


@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_frame_stack(env_fn):
_env = env_fn.env()
wrapped_env = frame_stack_v2(_env)
Expand Down Expand Up @@ -84,13 +99,17 @@ def test_pad_action_space_parallel(env_fn):
parallel_api_test(wrapped_env)


@pytest.mark.parametrize("env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4])
@pytest.mark.parametrize(
"env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction(env_fn):
env = supersuit.color_reduction_v0(env_fn.env(), "R")
api_test(env)


@pytest.mark.parametrize("env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4])
@pytest.mark.parametrize(
"env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction_parallel(env_fn):
env = supersuit.color_reduction_v0(env_fn.parallel_env(), "R")
parallel_api_test(env)
Expand Down Expand Up @@ -120,7 +139,12 @@ def test_resize_dtype_parallel(env_fn, wrapper_kwargs):
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", atari + butterfly + [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"])
@pytest.mark.parametrize(
"env_fn",
atari
+ butterfly
+ [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"],
)
def test_dtype(env_fn):
env = supersuit.dtype_v0(env_fn.env(), np.int32)
api_test(env)
Expand Down Expand Up @@ -227,7 +251,12 @@ def test_reward_lambda_parallel(env_fn):
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", [v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"] + mpe + sisl)
@pytest.mark.parametrize(
"env_fn",
[v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"]
+ mpe
+ sisl,
)
def test_observation_lambda(env_fn):
env = supersuit.observation_lambda_v0(env_fn.env(), lambda obs, obs_space: obs - 1)
api_test(env)
Expand Down Expand Up @@ -319,6 +348,7 @@ def test_nan_zeros_parallel(env_fn):
env = supersuit.nan_zeros_v0(env_fn.parallel_env())
parallel_api_test(env)


# Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_random(env_fn):
Expand All @@ -331,6 +361,7 @@ def test_nan_random_parallel(env_fn):
env = supersuit.nan_random_v0(env_fn.parallel_env())
parallel_api_test(env)


# Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_sticky_actions(env_fn):
Expand All @@ -343,6 +374,7 @@ def test_sticky_actions_parallel(env_fn):
env = supersuit.sticky_actions_v0(env_fn.parallel_env(), 0.75)
parallel_api_test(env)


# Note: hanabi_v5 and texas_holdem_v4 fail here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_delay_observations(env_fn):
Expand Down

0 comments on commit 88c51f2

Please sign in to comment.