diff --git a/metadrive/tests/test_functionality/test_nondeterminism.py b/metadrive/tests/test_functionality/test_nondeterminism.py index 95923b75e..57baf4c49 100644 --- a/metadrive/tests/test_functionality/test_nondeterminism.py +++ b/metadrive/tests/test_functionality/test_nondeterminism.py @@ -5,8 +5,9 @@ Usage: run this file. or pytest. """ +from collections import defaultdict + import numpy as np -import pandas as pd import pytest from metadrive.envs.metadrive_env import MetaDriveEnv @@ -42,22 +43,21 @@ def assert_dict_almost_equal(dict1, dict2, tol=1e-3): def are_traces_deterministic(traces) -> bool: - df = pd.DataFrame(traces) - - # grouping by repetition to get a list of traces - traces = df.groupby("repetition") - - # drop index and repetition ID to compare only step info later - stripped_traces = [trace.reset_index(drop=True).drop("repetition", axis=1) for _, trace in traces] + # Group traces by repetition + grouped_traces = defaultdict(list) + for trace in traces: + repetition_id = trace["repetition"] + grouped_traces[repetition_id].append({k: v for k, v in trace.items() if k != "repetition"}) - # iterate over each trace and check if it is equal to the first one - are_equal_to_first_trace = [trace.equals(stripped_traces[0]) for trace in stripped_traces] + # Convert traces to lists of dictionaries + stripped_traces = [sorted(group, key=lambda x: sorted(x.items())) for group in grouped_traces.values()] - first_trace = stripped_traces[0].to_dict() + # Compare each trace list to the first one + first_trace = stripped_traces[0] # This is a list of dictionaries for trace in stripped_traces: - # Assert - trace = trace.to_dict() - assert_dict_almost_equal(first_trace, trace) + assert len(trace) == len(first_trace) + for i in range(len(first_trace)): + assert_dict_almost_equal(first_trace[i], trace[i]) @pytest.mark.parametrize(