Skip to content

Commit

Permalink
fix test after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Oct 12, 2023
1 parent 83fe18a commit 47c4a68
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 65 deletions.
49 changes: 0 additions & 49 deletions tests/utils/test_dataset_combine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import gymnasium as gym
import h5py
import pytest
from gymnasium.utils.env_checker import data_equivalence
from packaging.specifiers import SpecifierSet
Expand Down Expand Up @@ -204,51 +203,3 @@ def test_combine_minari_version_specifiers(specifier_intersection, version_speci
intersection = combine_minari_version_specifiers(version_specifiers)

assert specifier_intersection == intersection


# in the future, if the logic of save metadata of combined dataset changes, this should be changed as well
def test_combine_dataset_with_different_metadata():
n_data = 2
dataset_list = []
for i in range(n_data):
dataset_id = f"cartpole-test-{i}-v0"
env = gym.make("CartPole-v1", max_episode_steps=500)
env = DataCollectorV0(env)
env.reset(seed=42)
for episode in range(5):
terminated = False
truncated = False
while not terminated and not truncated:
action = env.action_space.sample() # User-defined policy function
_, _, terminated, truncated, _ = env.step(action)
env.reset()

# Create Minari dataset and store locally
permalink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
dataset = minari.create_dataset_from_collector_env(
dataset_id=dataset_id,
collector_env=env,
algorithm_name="random_policy" + str(i),
code_permalink=permalink + str(i),
author="WillDudley" + str(i),
author_email="[email protected]" + str(i),
)
assert isinstance(dataset, MinariDataset)
env.close()
dataset_list.append(dataset)

combined_dataset = combine_datasets(
dataset_list, new_dataset_id="cartpole-combined-test-v0"
)
permalink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
with h5py.File(combined_dataset.spec.data_path) as dt_file:
assert dt_file.attrs["algorithm_name"] == "random_policy" + str(n_data - 1)
_final_code_link = permalink + str(n_data - 1)
assert dt_file.attrs["code_permalink"] == _final_code_link
assert dt_file.attrs["author"] == "WillDudley" + str(n_data - 1)
assert dt_file.attrs["author_email"] == "[email protected]" + str(n_data - 1)

for i in range(n_data):
minari.delete_dataset(f"cartpole-test-{i}-v0")
minari.delete_dataset("cartpole-combined-test-v0")
return
28 changes: 12 additions & 16 deletions tests/utils/test_dataset_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Dict

import gymnasium as gym
import h5py
import numpy as np
import pytest
from gymnasium import spaces
Expand Down Expand Up @@ -70,14 +69,12 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id):
author_email="[email protected]",
)

# test metadata

with h5py.File(dataset.spec.data_path, "r") as data_file:
assert data_file.attrs["algorithm_name"] == "random_policy"
codelink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
assert data_file.attrs["code_permalink"] == codelink
assert data_file.attrs["author"] == "WillDudley"
assert data_file.attrs["author_email"] == "[email protected]"
metadata = dataset.storage.metadata
assert metadata["algorithm_name"] == "random_policy"
codelink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
assert metadata["code_permalink"] == codelink
assert metadata["author"] == "WillDudley"
assert metadata["author_email"] == "[email protected]"

assert isinstance(dataset, MinariDataset)
assert dataset.total_episodes == num_episodes
Expand Down Expand Up @@ -287,13 +284,12 @@ def test_generate_dataset_with_space_subset_external_buffer():
observation_space=observation_space_subset,
)

# test metadata
with h5py.File(dataset.spec.data_path, "r") as data_file:
assert data_file.attrs["algorithm_name"] == "random_policy"
code_link = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
assert data_file.attrs["code_permalink"] == code_link
assert data_file.attrs["author"] == "WillDudley"
assert data_file.attrs["author_email"] == "[email protected]"
metadata = dataset.storage.metadata
assert metadata["algorithm_name"] == "random_policy"
code_link = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py"
assert metadata["code_permalink"] == code_link
assert metadata["author"] == "WillDudley"
assert metadata["author_email"] == "[email protected]"

assert isinstance(dataset, MinariDataset)
assert dataset.total_episodes == num_episodes
Expand Down

0 comments on commit 47c4a68

Please sign in to comment.