Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DataCollectorV0 and HDF5 dependencies isolation #133

Merged
merged 20 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
hooks:
- id: flake8
args:
- '--per-file-ignores=*/__init__.py:F401 test/all_parameter_combs_test.py:F405 pettingzoo/classic/go/go.py:W605'
- '--per-file-ignores=*/__init__.py:F401'
- --extend-ignore=E203
- --max-complexity=205
- --max-line-length=300
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- --count
# TODO: Remove ignoring rules D101, D102, D103, D105
- --add-ignore=D100,D107,D101,D102,D103,D105
exclude: "__init__.py$|^pettingzoo.test|^docs"
exclude: "__init__.py$|^docs"
additional_dependencies: ["toml"]
- repo: local
hooks:
Expand Down
23 changes: 12 additions & 11 deletions minari/data_collector/callbacks/episode_metadata.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import h5py
from typing import Dict

import numpy as np


class EpisodeMetadataCallback:
"""Callback to full episode after saving to hdf5 file as a group.

This callback can be overridden to add extra metadata attributes or statistics to
each HDF5 episode group in the Minari dataset. The custom callback can then be
each episode in the Minari dataset. The custom callback can then be
passed to the DataCollectorV0 wrapper to the `episode_metadata_callback` argument.

TODO: add more default statistics to episode datasets
"""

def __call__(self, eps_group: h5py.Group):
def __call__(self, episode: Dict):
"""Callback method.

Override this method to add custom attribute metadata to the episode group.

Args:
eps_group (h5py.Group): the HDF5 group that contains an episode's datasets
episode (dict): the dict that contains an episode's data
"""
eps_group["rewards"].attrs["sum"] = np.sum(eps_group["rewards"])
eps_group["rewards"].attrs["mean"] = np.mean(eps_group["rewards"])
eps_group["rewards"].attrs["std"] = np.std(eps_group["rewards"])
eps_group["rewards"].attrs["max"] = np.max(eps_group["rewards"])
eps_group["rewards"].attrs["min"] = np.min(eps_group["rewards"])

eps_group.attrs["total_steps"] = eps_group["rewards"].shape[0]
return {
"rewards_sum": np.sum(episode["rewards"]),
"rewards_mean": np.mean(episode["rewards"]),
"rewards_std": np.std(episode["rewards"]),
"rewards_max": np.max(episode["rewards"]),
"rewards_min": np.min(episode["rewards"]),
}
5 changes: 3 additions & 2 deletions minari/data_collector/callbacks/step_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class StepData(TypedDict):
"rewards",
"truncations",
"terminations",
"infos",
}


Expand Down Expand Up @@ -58,8 +59,8 @@ def __call__(self, env, **kwargs):

return step_data

The episode groups in the HDF5 file of this Minari dataset will contain a subgroup called `environment_states` with dataset `velocity` and another subgroup called `pose`
with datasets `position` and `orientation`
The Minari dataset will contain a dictionary called `environment_states` with `velocity` value and another dictionary `pose`
with `position` and `orientation`

Args:
env (gym.Env): current Gymnasium environment.
Expand Down
Loading