From 972d0cc170e10579f8942e74c95a88e702cf51a1 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 23 Aug 2023 12:15:01 +0200 Subject: [PATCH 01/19] init draft PR --- minari/__init__.py | 2 +- .../callbacks/episode_metadata.py | 18 +- minari/data_collector/data_collector.py | 354 +++--------------- minari/dataset/episode_data.py | 51 +++ minari/dataset/minari_dataset.py | 268 +++++++------ minari/dataset/minari_storage.py | 284 ++++---------- minari/storage/local.py | 1 + minari/utils.py | 332 ++++++++-------- tests/common.py | 4 +- 9 files changed, 516 insertions(+), 798 deletions(-) create mode 100644 minari/dataset/episode_data.py diff --git a/minari/__init__.py b/minari/__init__.py index df7c35ff..a65fe4a3 100644 --- a/minari/__init__.py +++ b/minari/__init__.py @@ -9,7 +9,7 @@ from minari.storage.local import delete_dataset, list_local_datasets, load_dataset from minari.utils import ( combine_datasets, - create_dataset_from_buffers, + # create_dataset_from_buffers, create_dataset_from_collector_env, get_normalized_score, split_dataset, diff --git a/minari/data_collector/callbacks/episode_metadata.py b/minari/data_collector/callbacks/episode_metadata.py index 3af389ba..5779cdd2 100644 --- a/minari/data_collector/callbacks/episode_metadata.py +++ b/minari/data_collector/callbacks/episode_metadata.py @@ -1,4 +1,4 @@ -import h5py +from typing import Dict import numpy as np @@ -12,7 +12,7 @@ class EpisodeMetadataCallback: TODO: add more default statistics to episode datasets """ - def __call__(self, eps_group: h5py.Group): + def __call__(self, episode: Dict): """Callback method. Override this method to add custom attribute metadata to the episode group. @@ -20,10 +20,10 @@ def __call__(self, eps_group: h5py.Group): Args: eps_group (h5py.Group): the HDF5 group that contains an episode's datasets """ - eps_group["rewards"].attrs["sum"] = np.sum(eps_group["rewards"]) - eps_group["rewards"].attrs["mean"] = np.mean(eps_group["rewards"]) - eps_group["rewards"].attrs["std"] = np.std(eps_group["rewards"]) - eps_group["rewards"].attrs["max"] = np.max(eps_group["rewards"]) - eps_group["rewards"].attrs["min"] = np.min(eps_group["rewards"]) - - eps_group.attrs["total_steps"] = eps_group["rewards"].shape[0] + return { + "rewards_sum": np.sum(episode["rewards"]), + "rewards_mean": np.mean(episode["rewards"]), + "rewards_std": np.std(episode["rewards"]), + "rewards_max": np.max(episode["rewards"]), + "rewards_min": np.min(episode["rewards"]) + } diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index e1e2bb4d..06a0b5a1 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -3,12 +3,9 @@ import os import shutil import tempfile -from collections import OrderedDict from typing import Any, Dict, List, Optional, SupportsFloat, Type, Union import gymnasium as gym -import h5py -import numpy as np from gymnasium.core import ActType, ObsType from minari.data_collector.callbacks import ( @@ -17,6 +14,7 @@ StepData, StepDataCallback, ) +from minari.dataset.minari_storage import MinariStorage from minari.serialization import serialize_space @@ -72,7 +70,6 @@ def __init__( ] = EpisodeMetadataCallback, record_infos: bool = False, max_buffer_steps: Optional[int] = None, - max_buffer_episodes: Optional[int] = None, observation_space=None, action_space=None, ): @@ -84,13 +81,13 @@ def __init__( episode_metadata_callback (type[EpisodeMetadataCallback], optional): Callback class to add custom metadata to episode group in HDF5 file. 
Defaults to EpisodeMetadataCallback. record_infos (bool, optional): If True record the info return key of each step. Defaults to False. max_buffer_steps (Optional[int], optional): number of steps saved in-memory buffers before dumping to HDF5 file in disk. Defaults to None. - max_buffer_episodes (Optional[int], optional): number of episodes saved in-memory buffers before dumping to HDF5 file in disk. Defaults to None. Raises: ValueError: `max_buffer_steps` and `max_buffer_episodes` can't be passed at the same time """ super().__init__(env) self._step_data_callback = step_data_callback() + self._episode_metadata_callback = episode_metadata_callback() if observation_space is None: observation_space = self.env.observation_space @@ -100,22 +97,14 @@ def __init__( action_space = self.env.action_space self.dataset_action_space = action_space - self._episode_metadata_callback = episode_metadata_callback() self._record_infos = record_infos - - if max_buffer_steps is not None and max_buffer_episodes is not None: - raise ValueError("Choose step or episode scheduler not both") - - self.max_buffer_episodes = max_buffer_episodes self.max_buffer_steps = max_buffer_steps # Initialzie empty buffer - self._buffer: List[EpisodeBuffer] = [{}] + self._buffer: List[EpisodeBuffer] = [] - self._current_seed: Union[int, str] = str(None) - self._new_episode = False - - self._step_id = 0 + self._step_id = -1 + self._episode_id = -1 # get path to minari datasets directory self.datasets_path = os.environ.get("MINARI_DATASETS_PATH") @@ -128,23 +117,13 @@ def __init__( os.makedirs(self.datasets_path) self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._tmp_f = h5py.File( - os.path.join(self._tmp_dir.name, "tmp_dataset.hdf5"), "a", track_order=True - ) # track insertion order of groups ('episodes_i') - + self._storage = MinariStorage(self._tmp_dir.name) assert self.env.spec is not None, "Env Spec is None" - self._tmp_f.attrs["env_spec"] = self.env.spec.to_json() - - self._new_episode = False - self._reset_called = False - - # Initialize first episode group in temporary hdf5 file - self._episode_id = 0 - self._eps_group: h5py.Group = self._tmp_f.create_group("episode_0") - self._eps_group.attrs["id"] = 0 - - self._last_episode_group_term_or_trunc = False - self._last_episode_n_steps = 0 + self._storage.update_metadata({ + "action_space": serialize_space(self.dataset_action_space), + "observation_space": serialize_space(self.dataset_observation_space), + "env_spec": self.env.spec.to_json() + }) def _add_to_episode_buffer( self, @@ -194,9 +173,8 @@ def step( """Gymnasium step method.""" obs, rew, terminated, truncated, info = self.env.step(action) - # add/edit data from step and convert to dictionary step data step_data = self._step_data_callback( - env=self, + env=self.env, obs=obs, info=info, action=action, @@ -204,12 +182,9 @@ def step( terminated=terminated, truncated=truncated, ) - # Force step data dictionary to include keys corresponding to Gymnasium step returns: - # actions, observations, rewards, terminations, truncations, and infos assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'." - # Check that the saved observation and action belong to the dataset's observation/action spaces assert self.dataset_observation_space.contains( step_data["observations"] ), "Observations are not in observation space." @@ -218,55 +193,16 @@ def step( ), "Actions are not in action space." 
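+        # When `max_buffer_steps` is set, the in-memory buffer is flushed to the
+        # temporary storage every `max_buffer_steps` steps; the episode currently
+        # being recorded then continues in a fresh buffer entry under its episode id.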
self._step_id += 1 + if ( + self.max_buffer_steps is not None + and self._step_id != 0 + and self._step_id % self.max_buffer_steps == 0 + ): + self._storage.update_episodes(self._buffer) + self._buffer = [{"id": self._episode_id}] - clear_buffers = False - # check if buffer needs to be cleared to temp file due to maximum step scheduler - if self.max_buffer_steps is not None: - clear_buffers = ( - self._step_id % self.max_buffer_steps == 0 and self._step_id != 0 - ) - - # Get initial observation/info from previous episode if reset has not been called after termination - # or truncation. This may happen if the step_data_callback truncates or terminates the episode under - # certain conditions. - if self._new_episode and not self._reset_called: - if isinstance(self._previous_eps_final_obs, dict): - self._buffer[-1]["observations"] = self._add_to_episode_buffer( - {}, self._previous_eps_final_obs - ) - else: - self._buffer[-1]["observations"] = [self._previous_eps_final_obs] - if self._record_infos: - self._buffer[-1]["infos"] = self._add_to_episode_buffer( - {}, self._previous_eps_final_info - ) - - self._new_episode = False - - # add step data to last episode buffer self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) - if step_data["terminations"] or step_data["truncations"]: - # Save last observation/info to use as initial observation/info in the next episode - self._previous_eps_final_obs = step_data["observations"] - if self._record_infos: - self._previous_eps_final_info = step_data["infos"] - self._reset_called = False - self._new_episode = True - self._buffer[-1]["seed"] = self._current_seed # type: ignore - # Only check episode scheduler to save in-memory data to temp HDF5 file when episode is done - if self.max_buffer_episodes is not None: - clear_buffers = (self._episode_id + 1) % self.max_buffer_episodes == 0 - - if clear_buffers: - self.clear_buffer_to_tmp_file() - - if clear_buffers or step_data["terminations"] or step_data["truncations"]: - self._buffer.append({}) - - if step_data["terminations"] or step_data["truncations"]: - self._episode_id += 1 - return obs, rew, terminated, truncated, info def reset( @@ -277,249 +213,62 @@ def reset( ) -> tuple[ObsType, dict[str, Any]]: """Gymnasium environment reset.""" obs, info = self.env.reset(seed=seed, options=options) - step_data = self._step_data_callback(env=self, obs=obs, info=info) + step_data = self._step_data_callback(env=self.env, obs=obs, info=info) + self._episode_id += 1 assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'" - # If last episode in global buffer has saved steps, we need to check if it was truncated or terminated - # If the last element in the buffer is not an empty dictionary, then we need to auto-truncate the episode. 
- if self._buffer[-1]: - if ( - not self._buffer[-1]["terminations"][-1] - and not self._buffer[-1]["truncations"][-1] - ): - self._buffer[-1]["truncations"][-1] = True - self._buffer[-1]["seed"] = self._current_seed # type: ignore - - # New episode - self._episode_id += 1 - - if ( - self.max_buffer_episodes is not None - and self._episode_id % self.max_buffer_episodes == 0 - ): - self.clear_buffer_to_tmp_file() - - # add new episode buffer - self._buffer.append({}) - else: - # In the case that the past episode is already stored in the tmp hdf5 file because of caching, - # we need to check if it was truncated or terminated, if not then auto-truncate - if ( - len(self._buffer) == 1 - and not self._last_episode_group_term_or_trunc - and self._episode_id != 0 - ): - self._eps_group["truncations"][-1] = True - self._last_episode_group_term_or_trunc = True - self._eps_group.attrs["seed"] = self._current_seed - - # New episode - self._episode_id += 1 - - # Compute metadata, use episode dataset in hdf5 file - self._episode_metadata_callback(self._eps_group) - - self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) - - if seed is None: - self._current_seed = str(None) - else: - self._current_seed = seed - - self._reset_called = True - + # Truncate the last episode in the buffer if it is not done + if ( + len(self._buffer) > 0 + and not self._buffer[-1]["terminations"][-1] + and not self._buffer[-1]["truncations"][-1] + ): + self._buffer[-1]["truncations"][-1] = True + + episode_buffer = { + "seed": seed if seed else str(None), + "id": self._episode_id + } + episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data) + self._buffer.append(episode_buffer) return obs, info - def clear_buffer_to_tmp_file(self, truncate_last_episode: bool = False): - """Save the global buffer in-memory to a temporary HDF5 file in disk. - - Args: - truncate_last_episode (bool, optional): If True the last episode from the buffer will be truncated before saving to disk. Defaults to False. - """ - - def get_h5py_subgroup(group: h5py.Group, name: str) -> h5py.Group: - """Get a subgroup from an h5py group. - - If the subgroup does not exist, create and return and empty group with the given name. - - Args: - group (h5py.Group): the h5py group object to look for/create the subgroup. - name (str): name of the subgroup. - - Returns: - subgroup (h5py.Group) - """ - if name in group: - subgroup = group.get(name) - assert isinstance(subgroup, h5py.Group) - else: - subgroup = group.create_group(name) - - return subgroup - - def clear_buffer(dictionary_buffer: EpisodeBuffer, episode_group: h5py.Group): - """Inner function to recursively save the nested data dictionaries in an episode buffer. - - Args: - dictionary_buffer (EpisodeBuffer): ditionary with keys to store as independent HDF5 datasets if the value is a list buffer - or create another group if value is a dictionary. - episode_group (h5py.Group): HDF5 group to store the datasets from the dictionary_buffer. 
- """ - for key, data in dictionary_buffer.items(): - if isinstance(data, dict): - eps_group_to_clear = get_h5py_subgroup(episode_group, key) - clear_buffer(data, eps_group_to_clear) - elif all(map(lambda elem: isinstance(elem, tuple), data)): - # we have a list of tuples, so we need to act appropriately - dict_data = { - f"_index_{str(i)}": [entry[i] for entry in data] - for i, _ in enumerate(data[0]) - } - eps_group_to_clear = get_h5py_subgroup(episode_group, key) - clear_buffer(dict_data, eps_group_to_clear) - elif all(map(lambda elem: isinstance(elem, OrderedDict), data)): - # we have a list of OrderedDicts, so we need to act appropriately - dict_data = { - key: [entry[key] for entry in data] - for key, value in data[0].items() - } - eps_group_to_clear = get_h5py_subgroup(episode_group, key) - clear_buffer(dict_data, eps_group_to_clear) - else: - if all(map(lambda elem: isinstance(elem, str), data)): - data_shape = (len(data),) - dtype = h5py.string_dtype(encoding="utf-8") - else: - data = np.asarray(data) - data_shape = data.shape - dtype = data.dtype - assert np.all( - np.logical_not(np.isnan(data)) - ), "Nan found after cast to nump array, check the type of 'data'." - - if ( - not self._last_episode_group_term_or_trunc - and key in episode_group - ): - current_dataset_shape = episode_group[key].shape[0] - episode_group[key].resize( - current_dataset_shape + len(data), axis=0 - ) - episode_group[key][-len(data) :] = data - else: - if not current_episode_group_term_or_trunc: - data_shape = (None,) + data_shape[1:] # resizable dataset - - episode_group.create_dataset( - key, data=data, maxshape=data_shape, dtype=dtype - ) - - for i, eps_buff in enumerate(self._buffer): - # Make sure that the episode has stepped, by checking if the 'actions' key has been added to the episode buffer. 
- if "actions" not in eps_buff: - continue - - current_episode_group_term_or_trunc = ( - eps_buff["terminations"][-1] or eps_buff["truncations"][-1] - ) - - # Check if last episode group is terminated or truncated - if self._last_episode_group_term_or_trunc: - # Add new episode group - current_episode_id = self._episode_id + i + 1 - len(self._buffer) - self._eps_group = self._tmp_f.create_group( - f"episode_{current_episode_id}" - ) - self._eps_group.attrs["id"] = current_episode_id - - if current_episode_group_term_or_trunc: - # Add seed to episode metadata if the current episode has finished - # Remove seed key from episode buffer before storing datasets to file - self._eps_group.attrs["seed"] = eps_buff.pop("seed") - clear_buffer(eps_buff, self._eps_group) - - if not self._last_episode_group_term_or_trunc: - self._last_episode_n_steps += len(eps_buff["actions"]) - else: - self._last_episode_n_steps = len(eps_buff["actions"]) - - if current_episode_group_term_or_trunc: - # Compute metadata, use episode dataset in hdf5 file - self._episode_metadata_callback(self._eps_group) - - self._last_episode_group_term_or_trunc = current_episode_group_term_or_trunc - - if not self._last_episode_group_term_or_trunc and truncate_last_episode: - self._eps_group["truncations"][-1] = True - self._last_episode_group_term_or_trunc = True - self._eps_group.attrs["seed"] = self._current_seed - - # New episode - self._episode_id += 1 - - # Compute metadata, use episode dataset in hdf5 file - self._episode_metadata_callback(self._eps_group) - - # Clear in-memory buffers - self._buffer.clear() - - def save_to_disk( - self, path: str, dataset_metadata: Optional[Dict[str, Any]] = None - ): + def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): """Save all in-memory buffer data and move temporary HDF5 file to a permanent location in disk. Args: path (str): path to store permanent HDF5, i.e: '/home/foo/datasets/data.hdf5' dataset_metadata (Dict, optional): additional metadata to add to HDF5 dataset file as attributes. Defaults to {}. """ - if dataset_metadata is None: - dataset_metadata = {} - - # Dump everything in memory buffers to tmp_dataset.hdf5 and truncate last episode - self.clear_buffer_to_tmp_file(truncate_last_episode=True) - - for key, value in dataset_metadata.items(): - self._tmp_f.attrs[key] = value + # truncate last episode + if not self._buffer[-1]["terminations"][-1] and not self._buffer[-1]["truncations"][-1]: + self._buffer[-1]["truncations"][-1] = True + + self._storage.update_episodes(self._buffer) #TODO: update add episode to use episode_id + self._buffer.clear() assert ( "observation_space" not in dataset_metadata.keys() ), "'observation_space' is not allowed as an optional key." assert ( "action_space" not in dataset_metadata.keys() - ), "'action_space' is not allowed as an optional key." - - action_space_str = serialize_space(self.dataset_action_space) - observation_space_str = serialize_space(self.dataset_observation_space) - - self._tmp_f.attrs["action_space"] = action_space_str - self._tmp_f.attrs["observation_space"] = observation_space_str + ), "'action_space' is not allowed as an optional key." + assert ( + "env_spec" not in dataset_metadata.keys() + ), "'env_spec' is not allowed as an optional key." 
+ self._storage.update_metadata(dataset_metadata) - self._buffer.append({}) + episode_metadata = self._storage.apply(self._episode_metadata_callback) + self._storage.update_episode_metadata(episode_metadata) + shutil.move(str(self._storage.data_path), path) + # Reset episode count self._episode_id = 0 - self._tmp_f.attrs["total_episodes"] = len(self._tmp_f.keys()) - self._tmp_f.attrs["total_steps"] = sum( - [ - episode_group.attrs["total_steps"] - for episode_group in self._tmp_f.values() - ] - ) - - # Close tmp_dataset.hdf5 - self._tmp_f.close() - - # Move tmp_dataset.hdf5 to specified directory - shutil.move(os.path.join(self._tmp_dir.name, "tmp_dataset.hdf5"), path) - - self._tmp_f = h5py.File( - os.path.join(self._tmp_dir.name, "tmp_dataset.hdf5"), "a", track_order=True - ) - def close(self): """Close the Gymnasium environment. @@ -531,5 +280,4 @@ def close(self): self._buffer.clear() # Close tmp_dataset.hdf5 - self._tmp_f.close() shutil.rmtree(self._tmp_dir.name) diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py new file mode 100644 index 00000000..db2b02e3 --- /dev/null +++ b/minari/dataset/episode_data.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Optional +import numpy as np + + +@dataclass(frozen=True) +class EpisodeData: + """Contains the datasets data for a single episode. + + This is the object returned by :class:`minari.MinariDataset.sample_episodes`. + """ + + id: int + seed: Optional[int] + total_timesteps: int + observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + terminations: np.ndarray + truncations: np.ndarray + + def __repr__(self) -> str: + return ( + "EpisodeData(" + f"id={repr(self.id)}, " + f"seed={repr(self.seed)}, " + f"total_timesteps={self.total_timesteps}, " + f"observations={EpisodeData._repr_space_values(self.observations)}, " + f"actions={EpisodeData._repr_space_values(self.actions)}, " + f"rewards=ndarray of {len(self.rewards)} floats, " + f"terminations=ndarray of {len(self.terminations)} bools, " + f"truncations=ndarray of {len(self.truncations)} bools" + ")" + ) + + @staticmethod + def _repr_space_values(value): + if isinstance(value, np.ndarray): + return f"ndarray of shape {value.shape} and dtype {value.dtype}" + elif isinstance(value, dict): + reprs = [ + f"{k}: {EpisodeData._repr_space_values(v)}" for k, v in value.items() + ] + dict_repr = ", ".join(reprs) + return "{" + dict_repr + "}" + elif isinstance(value, tuple): + reprs = [EpisodeData._repr_space_values(v) for v in value] + values_repr = ", ".join(reprs) + return "(" + values_repr + ")" + else: + return repr(value) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 8705a71b..4bb0b315 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -1,5 +1,7 @@ from __future__ import annotations +import importlib.metadata +import json import os import re from dataclasses import dataclass, field @@ -9,11 +11,17 @@ import numpy as np from gymnasium import error from gymnasium.envs.registration import EnvSpec +from packaging.specifiers import InvalidSpecifier, SpecifierSet +from packaging.version import Version -from minari.data_collector import DataCollectorV0 from minari.dataset.minari_storage import MinariStorage, PathLike +from minari.dataset.episode_data import EpisodeData + +# Use importlib due to circular import when: "from minari import __version__" +__version__ = importlib.metadata.version("minari") + DATASET_ID_RE = re.compile( 
r"(?:(?P[\w]+?))?(?:-(?P[\w:.-]+?))(?:-v(?P\d+))?$" ) @@ -41,54 +49,6 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int]: return env_name, dataset_name, version -@dataclass(frozen=True) -class EpisodeData: - """Contains the datasets data for a single episode. - - This is the object returned by :class:`minari.MinariDataset.sample_episodes`. - """ - - id: int - seed: Optional[int] - total_timesteps: int - observations: np.ndarray - actions: np.ndarray - rewards: np.ndarray - terminations: np.ndarray - truncations: np.ndarray - - def __repr__(self) -> str: - return ( - "EpisodeData(" - f"id={repr(self.id)}, " - f"seed={repr(self.seed)}, " - f"total_timesteps={self.total_timesteps}, " - f"observations={EpisodeData._repr_space_values(self.observations)}, " - f"actions={EpisodeData._repr_space_values(self.actions)}, " - f"rewards=ndarray of {len(self.rewards)} floats, " - f"terminations=ndarray of {len(self.terminations)} bools, " - f"truncations=ndarray of {len(self.truncations)} bools" - ")" - ) - - @staticmethod - def _repr_space_values(value): - if isinstance(value, np.ndarray): - return f"ndarray of shape {value.shape} and dtype {value.dtype}" - elif isinstance(value, dict): - reprs = [ - f"{k}: {EpisodeData._repr_space_values(v)}" for k, v in value.items() - ] - dict_repr = ", ".join(reprs) - return "{" + dict_repr + "}" - elif isinstance(value, tuple): - reprs = [EpisodeData._repr_space_values(v) for v in value] - values_repr = ", ".join(reprs) - return "(" + values_repr + ")" - else: - return repr(value) - - @dataclass class MinariDatasetSpec: env_spec: EnvSpec @@ -138,10 +98,65 @@ def __init__( else: raise ValueError(f"Unrecognized type {type(data)} for data") - self._additional_data_id = 0 + metadata = self._data.metadata + + env_spec = metadata["env_spec"] + assert isinstance(env_spec, str) + self._env_spec = EnvSpec.from_json(env_spec) + + dataset_id = metadata["dataset_id"] + assert isinstance(dataset_id, str) + self._dataset_id = dataset_id + + minari_version = metadata["minari_version"] + assert isinstance(minari_version, str) + + # Check that the dataset is compatible with the current version of Minari + try: + assert Version(__version__) in SpecifierSet( + minari_version + ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." + self._minari_version = minari_version + except InvalidSpecifier: + print(f"{minari_version} is not a version specifier.") + + self._combined_datasets = metadata.get("combined_datasets", []) + + # We will default to using the reconstructed observation and action spaces from the dataset + # and fall back to the env spec env if the action and observation spaces are not both present + # in the dataset. 
+ observation_space = metadata.get("observation_space") + action_space = metadata.get("action_space") + if observation_space is None or action_space is None: + # Checking if the base library of the environment is present in the environment + entry_point = json.loads(env_spec)["entry_point"] + lib_full_path = entry_point.split(":")[0] + base_lib = lib_full_path.split(".")[0] + env_name = self._env_spec.id + + try: + env = gym.make(self._env_spec) + if observation_space is None: + observation_space = env.observation_space + if action_space is None: + action_space = env.action_space + env.close() + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Install {base_lib} for loading {env_name} data" + ) from e + assert isinstance(observation_space, gym.spaces.Space) + assert isinstance(action_space, gym.spaces.Space) + self._observation_space = observation_space + self._action_space = action_space + if episode_indices is None: - episode_indices = np.arange(self._data.total_episodes) - total_steps = self._data.total_steps + total_episodes = metadata["total_episodes"] + assert isinstance(total_episodes, np.ndarray) + episode_indices = np.arange(total_episodes.item()) + total_steps = metadata["total_steps"] + assert isinstance(total_steps, np.ndarray) + total_steps = total_steps.item() else: total_steps = sum( self._data.apply( @@ -149,48 +164,34 @@ def __init__( episode_indices=episode_indices, ) ) - - self._episode_indices = episode_indices + + assert isinstance(episode_indices, np.ndarray) + self._episode_indices: np.ndarray = episode_indices + assert isinstance(total_steps, int) + self._total_steps = total_steps assert self._episode_indices is not None self.spec = MinariDatasetSpec( - env_spec=self._data.env_spec, + env_spec=self.env_spec, total_episodes=self._episode_indices.size, total_steps=total_steps, - dataset_id=self._data.id, - combined_datasets=self._data.combined_datasets, - observation_space=self._data.observation_space, - action_space=self._data.action_space, + dataset_id=self.id, + combined_datasets=self.combined_datasets, + observation_space=self.observation_space, + action_space=self.action_space, data_path=str(self._data.data_path), - minari_version=str(self._data.minari_version), + minari_version=str(self.minari_version), ) - self._total_steps = total_steps self._generator = np.random.default_rng() - @property - def total_episodes(self): - """Total episodes recorded in the Minari dataset.""" - assert self._episode_indices is not None - return self._episode_indices.size - - @property - def total_steps(self): - """Total episodes steps in the Minari dataset.""" - return self._total_steps - - @property - def episode_indices(self) -> np.ndarray: - """Indices of the available episodes to sample within the Minari dataset.""" - return self._episode_indices - def recover_environment(self) -> gym.Env: """Recover the Gymnasium environment used to create the dataset. Returns: environment: Gymnasium environment """ - return gym.make(self._data.env_spec) + return gym.make(self.env_spec) def set_seed(self, seed: int): """Set seed for random episode sampling generator.""" @@ -253,44 +254,44 @@ def iterate_episodes( data = self._data.get_episodes([episode_index])[0] yield EpisodeData(**data) - def update_dataset_from_collector_env(self, collector_env: DataCollectorV0): - """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). 
+ # def update_dataset_from_collector_env(self, collector_env: DataCollectorV0): + # """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). - This method can be used as a checkpoint when creating a dataset. - A new HDF5 file will be created with the new dataset file in the same directory as `main_data.hdf5` called - `additional_data_i.hdf5`. Both datasets are joined together by creating external links to each additional - episode group: https://docs.h5py.org/en/stable/high/group.html#external-links + # This method can be used as a checkpoint when creating a dataset. + # A new HDF5 file will be created with the new dataset file in the same directory as `main_data.hdf5` called + # `additional_data_i.hdf5`. Both datasets are joined together by creating external links to each additional + # episode group: https://docs.h5py.org/en/stable/high/group.html#external-links - Args: - collector_env (DataCollectorV0): Collector environment - """ - # check that collector env has the same characteristics as self._env_spec - new_data_file_path = os.path.join( - os.path.split(self.spec.data_path)[0], - f"additional_data_{self._additional_data_id}.hdf5", - ) + # Args: + # collector_env (DataCollectorV0): Collector environment + # """ + # # check that collector env has the same characteristics as self._env_spec + # new_data_file_path = os.path.join( + # os.path.split(self.spec.data_path)[0], + # f"additional_data_{self._additional_data_id}.hdf5", + # ) - old_total_episodes = self._data.total_episodes + # old_total_episodes = self._data.total_episodes - self._data.update_from_collector_env( - collector_env, new_data_file_path, self._additional_data_id - ) + # self._data.update_from_collector_env( + # collector_env, new_data_file_path, self._additional_data_id + # ) - new_total_episodes = self._data._total_episodes + # new_total_episodes = self._data._total_episodes - self._additional_data_id += 1 + # self._additional_data_id += 1 - self._episode_indices = np.append( - self._episode_indices, np.arange(old_total_episodes, new_total_episodes) - ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) + # self._episode_indices = np.append( + # self._episode_indices, np.arange(old_total_episodes, new_total_episodes) + # ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) - self.spec.total_episodes = self._episode_indices.size - self.spec.total_steps = sum( - self._data.apply( - lambda episode: episode["total_timesteps"], - episode_indices=self._episode_indices, - ) - ) + # self.spec.total_episodes = self._episode_indices.size + # self.spec.total_steps = sum( + # self._data.apply( + # lambda episode: episode["total_timesteps"], + # episode_indices=self._episode_indices, + # ) + # ) def update_dataset_from_buffer(self, buffer: List[dict]): """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers. 
@@ -308,10 +309,8 @@ def update_dataset_from_buffer(self, buffer: List[dict]): buffer (list[dict]): list of episode dictionary buffers to add to dataset """ old_total_episodes = self._data.total_episodes - - self._data.update_from_buffer(buffer, self.spec.data_path) - - new_total_episodes = self._data._total_episodes + self._data.update_episodes(buffer) + new_total_episodes = self._data.total_episodes self._episode_indices = np.append( self._episode_indices, np.arange(old_total_episodes, new_total_episodes) @@ -335,4 +334,47 @@ def __getitem__(self, idx: int) -> EpisodeData: return EpisodeData(**episodes_data[0]) def __len__(self) -> int: - return self.total_episodes + return len(self.episode_indices) + + @property + def total_steps(self): + """Total episodes steps in the Minari dataset.""" + return self._total_steps + + @property + def episode_indices(self) -> np.ndarray: + """Indices of the available episodes to sample within the Minari dataset.""" + return self._episode_indices + + @property + def observation_space(self): + """Original observation space of the environment before flatteining (if this is the case).""" + return self._observation_space + + @property + def action_space(self): + """Original action space of the environment before flatteining (if this is the case).""" + return self._action_space + + @property + def env_spec(self): + """Envspec of the environment that has generated the dataset.""" + return self._env_spec + + @property + def combined_datasets(self) -> List[str]: + """If this Minari dataset is a combination of other subdatasets, return a list with the subdataset names.""" + if self._combined_datasets is None: + return [] + else: + return self._combined_datasets + + @property + def id(self) -> str: + """Name of the Minari dataset.""" + return self._dataset_id + + @property + def minari_version(self) -> str: + """Version of Minari the dataset is compatible with.""" + return self._minari_version diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 3d3cae39..2c65ab69 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -1,5 +1,3 @@ -import importlib.metadata -import json import os from collections import OrderedDict from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -7,82 +5,46 @@ import gymnasium as gym import h5py import numpy as np -from gymnasium.envs.registration import EnvSpec -from packaging.specifiers import InvalidSpecifier, SpecifierSet -from packaging.version import Version -from minari.data_collector import DataCollectorV0 from minari.serialization import deserialize_space -# Use importlib due to circular import when: "from minari import __version__" -__version__ = importlib.metadata.version("minari") - PathLike = Union[str, bytes, os.PathLike] class MinariStorage: def __init__(self, data_path: PathLike): - """Initialize properties of the Minari storage. - - Args: - data_path (str): full path to the `main_data.hdf5` file of the dataset. 
- """ - self._data_path = data_path - self._extra_data_id = 0 - with h5py.File(self._data_path, "r") as f: - self._env_spec = EnvSpec.from_json(f.attrs["env_spec"]) - - total_episodes = f.attrs["total_episodes"].item() - assert isinstance(total_episodes, int) - self._total_episodes: int = total_episodes - - total_steps = f.attrs["total_steps"].item() - assert isinstance(total_steps, int) - self._total_steps: int = total_steps - - dataset_id = f.attrs["dataset_id"] - assert isinstance(dataset_id, str) - self._dataset_id = dataset_id - - minari_version = f.attrs["minari_version"] - assert isinstance(minari_version, str) - - # Check that the dataset is compatible with the current version of Minari - try: - assert Version(__version__) in SpecifierSet( - minari_version - ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." - self._minari_version = minari_version - except InvalidSpecifier: - print(f"{minari_version} is not a version specifier.") - - self._combined_datasets = f.attrs.get("combined_datasets", default=[]) - - # We will default to using the reconstructed observation and action spaces from the dataset - # and fall back to the env spec env if the action and observation spaces are not both present - # in the dataset. - if "action_space" in f.attrs and "observation_space" in f.attrs: - self._observation_space = deserialize_space( - f.attrs["observation_space"] - ) - self._action_space = deserialize_space(f.attrs["action_space"]) - else: - # Checking if the base library of the environment is present in the environment - entry_point = json.loads(f.attrs["env_spec"])["entry_point"] - lib_full_path = entry_point.split(":")[0] - base_lib = lib_full_path.split(".")[0] - env_name = self._env_spec.id + self._data_path = os.path.join(str(data_path), "main_data.hdf5") - try: - env = gym.make(self._env_spec) - self._observation_space = env.observation_space - self._action_space = env.action_space - env.close() - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Install {base_lib} for loading {env_name} data" - ) from e + @property + def metadata(self) -> Dict: + with h5py.File(self.data_path, "r") as file: + metadata = file.attrs + if "observation_space" in metadata.keys(): + space_serialization = metadata["observation_space"] + assert isinstance(space_serialization, Dict) + metadata["observation_space"] = deserialize_space(space_serialization) + if "action_space" in metadata.keys(): + space_serialization = metadata["action_space"] + assert isinstance(space_serialization, Dict) + metadata["action_space"] = deserialize_space(space_serialization) + + return dict(metadata) + + def update_metadata(self, metadata: Dict): + with h5py.File(self.data_path, "w") as file: + file.attrs.update(metadata) + + def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Optional[Iterable] = None): + if episode_indices is None: + episode_indices = range(self.total_episodes) + if len(metadatas) != len(list(episode_indices)): + raise ValueError("The number of metadatas doesn't match the number of episodes in the dataset.") + + with h5py.File(self.data_path, "w") as file: + for metadata, episode_id in zip(metadatas, episode_indices): + ep_group = file[f"episode_{episode_id}"] + ep_group.attrs.update(metadata) def apply( self, @@ -109,11 +71,12 @@ def apply( "id": ep_group.attrs.get("id"), "total_timesteps": ep_group.attrs.get("total_steps"), "seed": ep_group.attrs.get("seed"), + # TODO: self.metadata can be slow for decode 
space? Cache spaces? Cache metadata (bad for consistency)? "observations": self._decode_space( - ep_group["observations"], self.observation_space + ep_group["observations"], self.metadata["observation_space"] ), "actions": self._decode_space( - ep_group["actions"], self.action_space + ep_group["actions"], self.metadata["action_space"] ), "rewards": ep_group["rewards"][()], "terminations": ep_group["terminations"][()], @@ -169,10 +132,10 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: "total_timesteps": ep_group.attrs.get("total_steps"), "seed": ep_group.attrs.get("seed"), "observations": self._decode_space( - ep_group["observations"], self.observation_space + ep_group["observations"], self.metadata["observation_space"] ), "actions": self._decode_space( - ep_group["actions"], self.action_space + ep_group["actions"], self.metadata["action_space"] ), "rewards": ep_group["rewards"][()], "terminations": ep_group["terminations"][()], @@ -182,44 +145,10 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: return out - def update_from_collector_env( - self, - collector_env: DataCollectorV0, - new_data_file_path: str, - additional_data_id: int, - ): - - collector_env.save_to_disk(path=new_data_file_path) - - with h5py.File(new_data_file_path, "r", track_order=True) as new_data_file: - new_data_total_episodes = new_data_file.attrs["total_episodes"] - new_data_total_steps = new_data_file.attrs["total_steps"] - - with h5py.File(self.data_path, "a", track_order=True) as file: - last_episode_id = file.attrs["total_episodes"] - for id in range(new_data_total_episodes): - file[f"episode_{last_episode_id + id}"] = h5py.ExternalLink( - f"additional_data_{additional_data_id}.hdf5", f"/episode_{id}" - ) - file[f"episode_{last_episode_id + id}"].attrs.modify( - "id", last_episode_id + id - ) - - self._total_steps = file.attrs["total_steps"] + new_data_total_steps - - # Update metadata of minari dataset - file.attrs.modify( - "total_episodes", last_episode_id + new_data_total_episodes - ) - file.attrs.modify("total_steps", self._total_steps) - self._total_episodes = int(file.attrs["total_episodes"].item()) - - def update_from_buffer(self, buffer: List[dict], data_path: str): + def update_episodes(self, episodes: List[dict]): additional_steps = 0 - with h5py.File(data_path, "a", track_order=True) as file: - last_episode_id = file.attrs["total_episodes"] - for i, eps_buff in enumerate(buffer): - episode_id = last_episode_id + i + with h5py.File(self.data_path, "a", track_order=True) as file: + for eps_buff in episodes: # check episode terminated or truncated assert ( eps_buff["terminations"][-1] or eps_buff["truncations"][-1] @@ -229,128 +158,75 @@ def update_from_buffer(self, buffer: List[dict], data_path: str): ), f"Number of observations {len(eps_buff['observations'])} must have an additional \ element compared to the number of action steps {len(eps_buff['actions'])} \ The initial and final observation must be included" - seed = eps_buff.pop("seed", None) - episode_group = clear_episode_buffer( - eps_buff, file.create_group(f"episode_{episode_id}") - ) - + episode_id = eps_buff["id"] + episode_group = get_h5py_subgroup(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id + if "seed" in eps_buff.keys(): + assert not "seed" in episode_group.attrs.keys() + episode_group.attrs["seed"] = eps_buff["seed"] total_steps = len(eps_buff["actions"]) episode_group.attrs["total_steps"] = total_steps additional_steps += total_steps - if seed is None: - 
episode_group.attrs["seed"] = str(None) - else: - assert isinstance(seed, int) - episode_group.attrs["seed"] = seed - - # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata - - self._total_steps = file.attrs["total_steps"] + additional_steps - self._total_episodes = last_episode_id + len(buffer) + # TODO: make it append + _add_episode_to_group(eps_buff, episode_group) - file.attrs.modify("total_episodes", self._total_episodes) - file.attrs.modify("total_steps", self._total_steps) - - self._total_episodes = int(file.attrs["total_episodes"].item()) - - @property - def observation_space(self): - """Original observation space of the environment before flatteining (if this is the case).""" - return self._observation_space + total_steps = file.attrs["total_steps"] + additional_steps + total_episodes = len(file.keys()) - @property - def action_space(self): - """Original action space of the environment before flatteining (if this is the case).""" - return self._action_space + file.attrs.modify("total_episodes", total_episodes) + file.attrs.modify("total_steps", total_steps) @property - def data_path(self): + def data_path(self) -> PathLike: """Full path to the `main_data.hdf5` file of the dataset.""" return self._data_path @property - def total_steps(self): - """Total steps recorded in the Minari dataset along all episodes.""" - return self._total_steps - - @property - def total_episodes(self): - """Total episodes recorded in the Minari dataset.""" - return self._total_episodes - - @property - def env_spec(self): - """Envspec of the environment that has generated the dataset.""" - return self._env_spec - - @property - def combined_datasets(self) -> List[str]: - """If this Minari dataset is a combination of other subdatasets, return a list with the subdataset names.""" - if self._combined_datasets is None: - return [] - else: - return self._combined_datasets - - @property - def id(self) -> str: - """Name of the Minari dataset.""" - return self._dataset_id - - @property - def minari_version(self) -> str: - """Version of Minari the dataset is compatible with.""" - return self._minari_version - + def total_episodes(self) -> int: + """Total episodes in the dataset.""" + with h5py.File(self.data_path, "r") as file: + total_episodes = file.attrs["total_episodes"] + assert isinstance(total_episodes, np.ndarray) + total_episodes = total_episodes.item() + assert isinstance(total_episodes, int) + return total_episodes -def clear_episode_buffer(episode_buffer: Dict, episode_group: h5py.Group) -> h5py.Group: - """Save an episode dictionary buffer into an HDF5 episode group recursively. 
+def get_h5py_subgroup(group: h5py.Group, name: str) -> h5py.Group: + if name in group: + subgroup = group.get(name) + # assert isinstance(subgroup, h5py.Group) + else: + subgroup = group.create_group(name) - Args: - episode_buffer (dict): episode buffer - episode_group (h5py.Group): HDF5 group to store the episode datasets + return subgroup - Returns: - episode group: filled HDF5 episode group - """ +def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): for key, data in episode_buffer.items(): if isinstance(data, dict): - if key in episode_group: - episode_group_to_clear = episode_group[key] - else: - episode_group_to_clear = episode_group.create_group(key) - clear_episode_buffer(data, episode_group_to_clear) + episode_group_to_clear = get_h5py_subgroup(episode_group, key) + _add_episode_to_group(data, episode_group_to_clear) elif all([isinstance(entry, tuple) for entry in data]): # we have a list of tuples, so we need to act appropriately dict_data = { f"_index_{str(i)}": [entry[i] for entry in data] for i, _ in enumerate(data[0]) } - if key in episode_group: - episode_group_to_clear = episode_group[key] - else: - episode_group_to_clear = episode_group.create_group(key) - - clear_episode_buffer(dict_data, episode_group_to_clear) + episode_group_to_clear = get_h5py_subgroup(episode_group, key) + _add_episode_to_group(dict_data, episode_group_to_clear) elif all([isinstance(entry, OrderedDict) for entry in data]): - # we have a list of OrderedDicts, so we need to act appropriately dict_data = { key: [entry[key] for entry in data] for key, value in data[0].items() } - - if key in episode_group: - episode_group_to_clear = episode_group[key] + episode_group_to_clear = get_h5py_subgroup(episode_group, key) + _add_episode_to_group(dict_data, episode_group_to_clear) + else: # leaf data + if isinstance(episode_group, h5py.Dataset): + pass #TODO + elif all(map(lambda elem: isinstance(elem, str), data)): + dtype = h5py.string_dtype(encoding="utf-8") + episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True) else: - episode_group_to_clear = episode_group.create_group(key) - clear_episode_buffer(dict_data, episode_group_to_clear) - elif all(map(lambda elem: isinstance(elem, str), data)): - dtype = h5py.string_dtype(encoding="utf-8") - episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True) - else: - assert np.all(np.logical_not(np.isnan(data))) - # add seed to attributes - episode_group.create_dataset(key, data=data, chunks=True) - - return episode_group + assert np.all(np.logical_not(np.isnan(data))) + episode_group.create_dataset(key, data=data, chunks=True) \ No newline at end of file diff --git a/minari/storage/local.py b/minari/storage/local.py index efa16447..0403df38 100644 --- a/minari/storage/local.py +++ b/minari/storage/local.py @@ -67,6 +67,7 @@ def list_local_datasets( # Minari datasets must contain the data directory. 
continue + # TODO: remove hdf5 references main_file_path = os.path.join(datasets_path, dst_id, "data/main_data.hdf5") with h5py.File(main_file_path, "r") as f: metadata = dict(f.attrs.items()) diff --git a/minari/utils.py b/minari/utils.py index 0945c705..7a0048da 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -18,7 +18,7 @@ from minari import DataCollectorV0 from minari.dataset.minari_dataset import MinariDataset -from minari.dataset.minari_storage import clear_episode_buffer +# from minari.dataset.minari_storage import clear_episode_buffer from minari.serialization import serialize_space from minari.storage.datasets_root_dir import get_dataset_path @@ -351,171 +351,171 @@ def get_average_reference_score( return float(mean_ref_score) -def create_dataset_from_buffers( - dataset_id: str, - env: gym.Env, - buffer: List[Dict[str, Union[list, Dict]]], - algorithm_name: Optional[str] = None, - author: Optional[str] = None, - author_email: Optional[str] = None, - code_permalink: Optional[str] = None, - minari_version: Optional[str] = None, - action_space: Optional[gym.spaces.Space] = None, - observation_space: Optional[gym.spaces.Space] = None, - ref_min_score: Optional[float] = None, - ref_max_score: Optional[float] = None, - expert_policy: Optional[Callable[[ObsType], ActType]] = None, - num_episodes_average_score: int = 100, -): - """Create Minari dataset from a list of episode dictionary buffers. - - The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows: - ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. - This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. - - Each episode dictionary buffer must have the following items: - * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation - * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). - * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). - * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). - * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). - - Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. - - Args: - dataset_id (str): name id to identify Minari dataset - env (gym.Env): Gymnasium environment used to collect the buffer data - buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data - algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None. - author (Optional[str], optional): author that generated the dataset. Defaults to None. - author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None. - code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None. - ref_min_score (Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy. - Also note that this attribute will be added to the Minari dataset only if `ref_max_score` or `expert_policy` are assigned a valid value other than None. 
- ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in `MinariDataset.get_normalized_score()`. Default None. - expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to `num_episodes_average_score`. - `ref_max_score` and `expert_policy` can't be passed at the same time. Default to None - num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100. - - Returns: - MinariDataset - """ - # NoneType warnings - if code_permalink is None: - warnings.warn( - "`code_permalink` is set to None. For reproducibility purposes it is highly recommended to link your dataset to versioned code.", - UserWarning, - ) - if author is None: - warnings.warn( - "`author` is set to None. For longevity purposes it is highly recommended to provide an author name.", - UserWarning, - ) - if author_email is None: - warnings.warn( - "`author_email` is set to None. For longevity purposes it is highly recommended to provide an author email, or some other obvious contact information.", - UserWarning, - ) - if minari_version is None: - version = Version(__version__) - release = version.release - # For __version__ = X.Y.Z, set version specifier by default to compatibility with version X.Y or later, but not (X+1).0 or later. - minari_version = f"~={'.'.join(str(x) for x in release[:2])}" - warnings.warn( - f"`minari_version` is set to None. The compatible dataset version specifier for Minari will be set to {minari_version}.", - UserWarning, - ) - # Check if the installed Minari version falls inside the minari_version specifier - try: - assert Version(__version__) in SpecifierSet( - minari_version - ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." - except InvalidSpecifier: - print(f"{minari_version} is not a version specifier.") - - if observation_space is None: - observation_space = env.observation_space - if action_space is None: - action_space = env.action_space - - if expert_policy is not None and ref_max_score is not None: - raise ValueError( - "Can't pass a value for `expert_policy` and `ref_max_score` at the same time." - ) - - dataset_path = get_dataset_path(dataset_id) - - # Check if dataset already exists - if not os.path.exists(dataset_path): - dataset_path = os.path.join(dataset_path, "data") - os.makedirs(dataset_path) - data_path = os.path.join(dataset_path, "main_data.hdf5") - - total_steps = 0 - with h5py.File(data_path, "w", track_order=True) as file: - for i, eps_buff in enumerate(buffer): - # check episode terminated or truncated - assert ( - eps_buff["terminations"][-1] or eps_buff["truncations"][-1] - ), "Each episode must be terminated or truncated before adding it to a Minari dataset" - assert len(eps_buff["actions"]) + 1 == len( - eps_buff["observations"] - ), f"Number of observations {len(eps_buff['observations'])} must have an additional element compared to the number of action steps {len(eps_buff['actions'])}. 
The initial and final observation must be included" - seed = eps_buff.pop("seed", None) - eps_group = clear_episode_buffer( - eps_buff, file.create_group(f"episode_{i}") - ) - - eps_group.attrs["id"] = i - episode_total_steps = len(eps_buff["actions"]) - eps_group.attrs["total_steps"] = episode_total_steps - total_steps += episode_total_steps - - if seed is None: - eps_group.attrs["seed"] = str(None) - else: - assert isinstance(seed, int) - eps_group.attrs["seed"] = seed - - # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata - - file.attrs["total_episodes"] = len(buffer) - file.attrs["total_steps"] = total_steps - - file.attrs[ - "env_spec" - ] = env.spec.to_json() # pyright: ignore [reportOptionalMemberAccess] - file.attrs["dataset_id"] = dataset_id - - action_space_str = serialize_space(action_space) - observation_space_str = serialize_space(observation_space) - - file.attrs["action_space"] = action_space_str - file.attrs["observation_space"] = observation_space_str - - if expert_policy is not None or ref_max_score is not None: - env = copy.deepcopy(env) - if ref_min_score is None: - ref_min_score = get_average_reference_score( - env, RandomPolicy(env), num_episodes_average_score - ) - - if expert_policy is not None: - ref_max_score = get_average_reference_score( - env, expert_policy, num_episodes_average_score - ) - - file.attrs["ref_max_score"] = ref_max_score - file.attrs["ref_min_score"] = ref_min_score - file.attrs["num_episodes_average_score"] = num_episodes_average_score - - file.attrs["minari_version"] = minari_version - - return MinariDataset(data_path) - else: - raise ValueError( - f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." - ) +# def create_dataset_from_buffers( +# dataset_id: str, +# env: gym.Env, +# buffer: List[Dict[str, Union[list, Dict]]], +# algorithm_name: Optional[str] = None, +# author: Optional[str] = None, +# author_email: Optional[str] = None, +# code_permalink: Optional[str] = None, +# minari_version: Optional[str] = None, +# action_space: Optional[gym.spaces.Space] = None, +# observation_space: Optional[gym.spaces.Space] = None, +# ref_min_score: Optional[float] = None, +# ref_max_score: Optional[float] = None, +# expert_policy: Optional[Callable[[ObsType], ActType]] = None, +# num_episodes_average_score: int = 100, +# ): +# """Create Minari dataset from a list of episode dictionary buffers. + +# The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows: +# ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. +# This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. + +# Each episode dictionary buffer must have the following items: +# * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation +# * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). +# * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). +# * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). +# * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). + +# Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. 
+ +# Args: +# dataset_id (str): name id to identify Minari dataset +# env (gym.Env): Gymnasium environment used to collect the buffer data +# buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data +# algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None. +# author (Optional[str], optional): author that generated the dataset. Defaults to None. +# author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None. +# code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None. +# ref_min_score (Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy. +# Also note that this attribute will be added to the Minari dataset only if `ref_max_score` or `expert_policy` are assigned a valid value other than None. +# ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in `MinariDataset.get_normalized_score()`. Default None. +# expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to `num_episodes_average_score`. +# `ref_max_score` and `expert_policy` can't be passed at the same time. Default to None +# num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100. + +# Returns: +# MinariDataset +# """ +# # NoneType warnings +# if code_permalink is None: +# warnings.warn( +# "`code_permalink` is set to None. For reproducibility purposes it is highly recommended to link your dataset to versioned code.", +# UserWarning, +# ) +# if author is None: +# warnings.warn( +# "`author` is set to None. For longevity purposes it is highly recommended to provide an author name.", +# UserWarning, +# ) +# if author_email is None: +# warnings.warn( +# "`author_email` is set to None. For longevity purposes it is highly recommended to provide an author email, or some other obvious contact information.", +# UserWarning, +# ) +# if minari_version is None: +# version = Version(__version__) +# release = version.release +# # For __version__ = X.Y.Z, set version specifier by default to compatibility with version X.Y or later, but not (X+1).0 or later. +# minari_version = f"~={'.'.join(str(x) for x in release[:2])}" +# warnings.warn( +# f"`minari_version` is set to None. The compatible dataset version specifier for Minari will be set to {minari_version}.", +# UserWarning, +# ) +# # Check if the installed Minari version falls inside the minari_version specifier +# try: +# assert Version(__version__) in SpecifierSet( +# minari_version +# ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." +# except InvalidSpecifier: +# print(f"{minari_version} is not a version specifier.") + +# if observation_space is None: +# observation_space = env.observation_space +# if action_space is None: +# action_space = env.action_space + +# if expert_policy is not None and ref_max_score is not None: +# raise ValueError( +# "Can't pass a value for `expert_policy` and `ref_max_score` at the same time." 
+# ) + +# dataset_path = get_dataset_path(dataset_id) + +# # Check if dataset already exists +# if not os.path.exists(dataset_path): +# dataset_path = os.path.join(dataset_path, "data") +# os.makedirs(dataset_path) +# data_path = os.path.join(dataset_path, "main_data.hdf5") + +# total_steps = 0 +# with h5py.File(data_path, "w", track_order=True) as file: +# for i, eps_buff in enumerate(buffer): +# # check episode terminated or truncated +# assert ( +# eps_buff["terminations"][-1] or eps_buff["truncations"][-1] +# ), "Each episode must be terminated or truncated before adding it to a Minari dataset" +# assert len(eps_buff["actions"]) + 1 == len( +# eps_buff["observations"] +# ), f"Number of observations {len(eps_buff['observations'])} must have an additional element compared to the number of action steps {len(eps_buff['actions'])}. The initial and final observation must be included" +# seed = eps_buff.pop("seed", None) +# eps_group = clear_episode_buffer( +# eps_buff, file.create_group(f"episode_{i}") +# ) + +# eps_group.attrs["id"] = i +# episode_total_steps = len(eps_buff["actions"]) +# eps_group.attrs["total_steps"] = episode_total_steps +# total_steps += episode_total_steps + +# if seed is None: +# eps_group.attrs["seed"] = str(None) +# else: +# assert isinstance(seed, int) +# eps_group.attrs["seed"] = seed + +# # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata + +# file.attrs["total_episodes"] = len(buffer) +# file.attrs["total_steps"] = total_steps + +# file.attrs[ +# "env_spec" +# ] = env.spec.to_json() # pyright: ignore [reportOptionalMemberAccess] +# file.attrs["dataset_id"] = dataset_id + +# action_space_str = serialize_space(action_space) +# observation_space_str = serialize_space(observation_space) + +# file.attrs["action_space"] = action_space_str +# file.attrs["observation_space"] = observation_space_str + +# if expert_policy is not None or ref_max_score is not None: +# env = copy.deepcopy(env) +# if ref_min_score is None: +# ref_min_score = get_average_reference_score( +# env, RandomPolicy(env), num_episodes_average_score +# ) + +# if expert_policy is not None: +# ref_max_score = get_average_reference_score( +# env, expert_policy, num_episodes_average_score +# ) + +# file.attrs["ref_max_score"] = ref_max_score +# file.attrs["ref_min_score"] = ref_min_score +# file.attrs["num_episodes_average_score"] = num_episodes_average_score + +# file.attrs["minari_version"] = minari_version + +# return MinariDataset(data_path) +# else: +# raise ValueError( +# f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." 
+# ) def create_dataset_from_collector_env( diff --git a/tests/common.py b/tests/common.py index bdf73727..d8a7f5f6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -570,7 +570,7 @@ def create_dummy_dataset_with_collecter_env_helper( env.reset() # Create Minari dataset and store locally - return minari.create_dataset_from_collector_env( + dataset = minari.create_dataset_from_collector_env( dataset_id=dataset_id, collector_env=env, algorithm_name="random_policy", @@ -579,7 +579,7 @@ def create_dummy_dataset_with_collecter_env_helper( author_email="wdudley@farama.org", ) env.close() - + return dataset def check_episode_data_integrity( episode_data_list: List[EpisodeData], From 1c24f395f0f08302212b3d9047853984ba83163e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 24 Aug 2023 15:45:21 +0200 Subject: [PATCH 02/19] fix bugs --- minari/data_collector/callbacks/step_data.py | 5 +- minari/data_collector/data_collector.py | 51 +++++++++--------- minari/dataset/minari_dataset.py | 14 ++--- minari/dataset/minari_storage.py | 54 +++++++++++--------- 4 files changed, 67 insertions(+), 57 deletions(-) diff --git a/minari/data_collector/callbacks/step_data.py b/minari/data_collector/callbacks/step_data.py index fd849296..0c73db94 100644 --- a/minari/data_collector/callbacks/step_data.py +++ b/minari/data_collector/callbacks/step_data.py @@ -19,6 +19,7 @@ class StepData(TypedDict): "rewards", "truncations", "terminations", + "infos" } @@ -58,8 +59,8 @@ def __call__(self, env, **kwargs): return step_data - The episode groups in the HDF5 file of this Minari dataset will contain a subgroup called `environment_states` with dataset `velocity` and another subgroup called `pose` - with datasets `position` and `orientation` + The Minari dataset will contain a dictionary called `environment_states` with `velocity` value and another dictionary `pose` + with `position` and `orientation` Args: env (gym.Env): current Gymnasium environment. diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 06a0b5a1..953c131a 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -117,13 +117,13 @@ def __init__( os.makedirs(self.datasets_path) self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._storage = MinariStorage(self._tmp_dir.name) assert self.env.spec is not None, "Env Spec is None" - self._storage.update_metadata({ - "action_space": serialize_space(self.dataset_action_space), - "observation_space": serialize_space(self.dataset_observation_space), - "env_spec": self.env.spec.to_json() - }) + self._storage = MinariStorage.new( + os.path.join(self._tmp_dir.name, "main_data"), + action_space=serialize_space(self.dataset_action_space), + observation_space=serialize_space(self.dataset_observation_space), + env_spec=self.env.spec.to_json() + ) def _add_to_episode_buffer( self, @@ -141,8 +141,6 @@ def _add_to_episode_buffer( """ for key, value in step_data.items(): if (not self._record_infos and key == "infos") or (value is None): - # if the step data comes from a reset call: skip actions, rewards, - # terminations, and truncations their values are set to None in the StepDataCallback continue if key not in episode_buffer: @@ -193,6 +191,8 @@ def step( ), "Actions are not in action space." 
self._step_id += 1 + self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) + if ( self.max_buffer_steps is not None and self._step_id != 0 @@ -200,8 +200,15 @@ def step( ): self._storage.update_episodes(self._buffer) self._buffer = [{"id": self._episode_id}] - - self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) + if step_data["terminations"] or step_data["truncations"]: + self._episode_id += 1 + eps_buff = {"id": self._episode_id} + previous_data = { + "observations": step_data["observations"], + "infos": step_data["infos"] + } + eps_buff = self._add_to_episode_buffer(eps_buff, previous_data) + self._buffer.append(eps_buff) return obs, rew, terminated, truncated, info @@ -220,14 +227,7 @@ def reset( step_data.keys() ), "One or more required keys is missing from 'step-data'" - # Truncate the last episode in the buffer if it is not done - if ( - len(self._buffer) > 0 - and not self._buffer[-1]["terminations"][-1] - and not self._buffer[-1]["truncations"][-1] - ): - self._buffer[-1]["truncations"][-1] = True - + self._validate_buffer() episode_buffer = { "seed": seed if seed else str(None), "id": self._episode_id @@ -236,6 +236,14 @@ def reset( self._buffer.append(episode_buffer) return obs, info + def _validate_buffer(self): + if len(self._buffer) > 0: + if "actions" not in self._buffer[-1].keys(): + self._buffer.pop() + self._episode_id -= 1 + elif not self._buffer[-1]["terminations"][-1]: + self._buffer[-1]["truncations"][-1] = True + def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): """Save all in-memory buffer data and move temporary HDF5 file to a permanent location in disk. @@ -243,11 +251,8 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): path (str): path to store permanent HDF5, i.e: '/home/foo/datasets/data.hdf5' dataset_metadata (Dict, optional): additional metadata to add to HDF5 dataset file as attributes. Defaults to {}. 
""" - # truncate last episode - if not self._buffer[-1]["terminations"][-1] and not self._buffer[-1]["truncations"][-1]: - self._buffer[-1]["truncations"][-1] = True - - self._storage.update_episodes(self._buffer) #TODO: update add episode to use episode_id + self._validate_buffer() + self._storage.update_episodes(self._buffer) self._buffer.clear() assert ( diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 4bb0b315..9594f7a0 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -152,11 +152,8 @@ def __init__( if episode_indices is None: total_episodes = metadata["total_episodes"] - assert isinstance(total_episodes, np.ndarray) - episode_indices = np.arange(total_episodes.item()) + episode_indices = np.arange(total_episodes) total_steps = metadata["total_steps"] - assert isinstance(total_steps, np.ndarray) - total_steps = total_steps.item() else: total_steps = sum( self._data.apply( @@ -167,7 +164,6 @@ def __init__( assert isinstance(episode_indices, np.ndarray) self._episode_indices: np.ndarray = episode_indices - assert isinstance(total_steps, int) self._total_steps = total_steps assert self._episode_indices is not None @@ -334,12 +330,16 @@ def __getitem__(self, idx: int) -> EpisodeData: return EpisodeData(**episodes_data[0]) def __len__(self) -> int: + return self.total_episodes + + @property + def total_episodes(self) -> int: return len(self.episode_indices) @property - def total_steps(self): + def total_steps(self) -> int: """Total episodes steps in the Minari dataset.""" - return self._total_steps + return int(self._total_steps) @property def episode_indices(self) -> np.ndarray: diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 2c65ab69..604dcd0b 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -14,25 +14,41 @@ class MinariStorage: def __init__(self, data_path: PathLike): - self._data_path = os.path.join(str(data_path), "main_data.hdf5") + base_path, ext = os.path.splitext(str(data_path)) + if ext != "" and ext != ".hdf5": + raise ValueError(f"Only hdf5 extension is supported, found {ext}") + self._data_path = base_path + ".hdf5" + + @classmethod + def new(cls, data_path: PathLike, action_space, observation_space, env_spec): + obj = cls(data_path) + obj.update_metadata({ + "action_space": action_space, + "observation_space": observation_space, + "env_spec": env_spec, + "total_episodes": 0, + "total_steps": 0 + }) + return obj @property def metadata(self) -> Dict: + metadata = {} with h5py.File(self.data_path, "r") as file: - metadata = file.attrs + metadata.update(file.attrs) if "observation_space" in metadata.keys(): space_serialization = metadata["observation_space"] - assert isinstance(space_serialization, Dict) + assert isinstance(space_serialization, str) metadata["observation_space"] = deserialize_space(space_serialization) if "action_space" in metadata.keys(): space_serialization = metadata["action_space"] - assert isinstance(space_serialization, Dict) + assert isinstance(space_serialization, str) metadata["action_space"] = deserialize_space(space_serialization) - return dict(metadata) + return metadata def update_metadata(self, metadata: Dict): - with h5py.File(self.data_path, "w") as file: + with h5py.File(self.data_path, "a") as file: file.attrs.update(metadata) def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Optional[Iterable] = None): @@ -41,7 +57,7 @@ def update_episode_metadata(self, metadatas: 
List[Dict], episode_indices: Option if len(metadatas) != len(list(episode_indices)): raise ValueError("The number of metadatas doesn't match the number of episodes in the dataset.") - with h5py.File(self.data_path, "w") as file: + with h5py.File(self.data_path, "a") as file: for metadata, episode_id in zip(metadatas, episode_indices): ep_group = file[f"episode_{episode_id}"] ep_group.attrs.update(metadata) @@ -149,22 +165,13 @@ def update_episodes(self, episodes: List[dict]): additional_steps = 0 with h5py.File(self.data_path, "a", track_order=True) as file: for eps_buff in episodes: - # check episode terminated or truncated - assert ( - eps_buff["terminations"][-1] or eps_buff["truncations"][-1] - ), "Each episode must be terminated or truncated before adding it to a Minari dataset" - assert len(eps_buff["actions"]) + 1 == len( - eps_buff["observations"] - ), f"Number of observations {len(eps_buff['observations'])} must have an additional \ - element compared to the number of action steps {len(eps_buff['actions'])} \ - The initial and final observation must be included" - episode_id = eps_buff["id"] + episode_id = eps_buff.pop("id") episode_group = get_h5py_subgroup(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id if "seed" in eps_buff.keys(): assert not "seed" in episode_group.attrs.keys() - episode_group.attrs["seed"] = eps_buff["seed"] - total_steps = len(eps_buff["actions"]) + episode_group.attrs["seed"] = eps_buff.pop("seed") + total_steps = len(eps_buff["rewards"]) episode_group.attrs["total_steps"] = total_steps additional_steps += total_steps @@ -183,19 +190,16 @@ def data_path(self) -> PathLike: return self._data_path @property - def total_episodes(self) -> int: + def total_episodes(self) -> np.int64: """Total episodes in the dataset.""" with h5py.File(self.data_path, "r") as file: total_episodes = file.attrs["total_episodes"] - assert isinstance(total_episodes, np.ndarray) - total_episodes = total_episodes.item() - assert isinstance(total_episodes, int) + assert type(total_episodes) == np.int64 return total_episodes -def get_h5py_subgroup(group: h5py.Group, name: str) -> h5py.Group: +def get_h5py_subgroup(group: h5py.Group, name: str): if name in group: subgroup = group.get(name) - # assert isinstance(subgroup, h5py.Group) else: subgroup = group.create_group(name) From 8d8f7993a31f2c01650a638d74e0b467afa90203 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 24 Aug 2023 16:13:41 +0200 Subject: [PATCH 03/19] remove from hosting.py --- minari/storage/hosting.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py index 674055be..d26e9c84 100644 --- a/minari/storage/hosting.py +++ b/minari/storage/hosting.py @@ -6,13 +6,13 @@ import warnings from typing import Dict, List -import h5py from google.cloud import storage # pyright: ignore [reportGeneralTypeIssues] from gymnasium import logger from packaging.specifiers import SpecifierSet from tqdm.auto import tqdm # pyright: ignore [reportMissingModuleSource] from minari.dataset.minari_dataset import parse_dataset_id +from minari.dataset.minari_storage import MinariStorage from minari.storage.datasets_root_dir import get_dataset_path from minari.storage.local import load_dataset @@ -56,8 +56,7 @@ def _upload_local_directory_to_gcs(local_path, bucket, gcs_path): dataset = load_dataset(dataset_id) - with h5py.File(dataset.spec.data_path, "r") as f: - metadata = dict(f.attrs.items()) + metadata = MinariStorage(dataset.spec.data_path).metadata # 
See https://github.com/googleapis/python-storage/issues/27 for discussion on progress bars _upload_local_directory_to_gcs(str(file_path), bucket, dataset_id) From e6355ce814d302e1646872329ecebe8e94610898 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 3 Sep 2023 18:59:39 -0400 Subject: [PATCH 04/19] fix some tests --- minari/dataset/minari_storage.py | 8 ++++++++ tests/common.py | 12 ++++++++---- tests/utils/test_dataset_creation.py | 10 +++------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 604dcd0b..9cd98c4e 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -197,6 +197,14 @@ def total_episodes(self) -> np.int64: assert type(total_episodes) == np.int64 return total_episodes + @property + def total_steps(self) -> np.int64: + """Total steps in the dataset.""" + with h5py.File(self.data_path, "r") as file: + total_episodes = file.attrs["total_steps"] + assert type(total_episodes) == np.int64 + return total_episodes + def get_h5py_subgroup(group: h5py.Group, name: str): if name in group: subgroup = group.get(name) diff --git a/tests/common.py b/tests/common.py index d8a7f5f6..2f95637b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -461,26 +461,30 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): # verify we have the right number of episodes, available at the right indices assert data.total_episodes == len(episodes) total_steps = 0 + + observation_space = data.metadata["observation_space"] + action_space = data.metadata["action_space"] + # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct for episode in episodes: total_steps += episode["total_timesteps"] _check_space_elem( episode["observations"], - data.observation_space, + observation_space, episode["total_timesteps"] + 1, ) _check_space_elem( - episode["actions"], data.action_space, episode["total_timesteps"] + episode["actions"], action_space, episode["total_timesteps"] ) for i in range(episode["total_timesteps"] + 1): obs = _reconstuct_obs_or_action_at_index_recursive( episode["observations"], i ) - assert data.observation_space.contains(obs) + assert observation_space.contains(obs) for i in range(episode["total_timesteps"]): action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i) - assert data.action_space.contains(action) + assert action_space.contains(action) assert episode["total_timesteps"] == len(episode["rewards"]) assert episode["total_timesteps"] == len(episode["terminations"]) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 924ce12a..17ce6e3d 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -50,15 +50,11 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): env.reset(seed=42) for episode in range(num_episodes): - terminated = False - truncated = False - while not terminated and not truncated: + done = False + while not done: action = env.action_space.sample() # User-defined policy function _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] + done = terminated or truncated env.reset() From d83ba32793b0eee2f12357b6c45631010cf85b7f Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 7 Sep 2023 11:47:00 -0400 Subject: [PATCH 05/19] fix problems --- 
minari/__init__.py | 2 +- minari/data_collector/data_collector.py | 20 +- minari/dataset/minari_dataset.py | 7 +- minari/dataset/minari_storage.py | 71 ++-- minari/storage/datasets_root_dir.py | 2 +- minari/storage/local.py | 2 +- minari/utils.py | 391 +++++++++----------- tests/common.py | 6 +- tests/data_collector/test_data_collector.py | 1 - tests/dataset/test_minari_dataset.py | 45 ++- tests/dataset/test_minari_storage.py | 42 +-- 11 files changed, 271 insertions(+), 318 deletions(-) diff --git a/minari/__init__.py b/minari/__init__.py index a65fe4a3..df7c35ff 100644 --- a/minari/__init__.py +++ b/minari/__init__.py @@ -9,7 +9,7 @@ from minari.storage.local import delete_dataset, list_local_datasets, load_dataset from minari.utils import ( combine_datasets, - # create_dataset_from_buffers, + create_dataset_from_buffers, create_dataset_from_collector_env, get_normalized_score, split_dataset, diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 953c131a..95726dcc 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -15,7 +15,6 @@ StepDataCallback, ) from minari.dataset.minari_storage import MinariStorage -from minari.serialization import serialize_space EpisodeBuffer = Dict[str, Any] # TODO: narrow this down @@ -119,10 +118,10 @@ def __init__( self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) assert self.env.spec is not None, "Env Spec is None" self._storage = MinariStorage.new( - os.path.join(self._tmp_dir.name, "main_data"), - action_space=serialize_space(self.dataset_action_space), - observation_space=serialize_space(self.dataset_observation_space), - env_spec=self.env.spec.to_json() + self._tmp_dir.name, + action_space=self.dataset_action_space, + observation_space=self.dataset_observation_space, + env_spec=self.env.spec ) def _add_to_episode_buffer( @@ -248,7 +247,7 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): """Save all in-memory buffer data and move temporary HDF5 file to a permanent location in disk. Args: - path (str): path to store permanent HDF5, i.e: '/home/foo/datasets/data.hdf5' + path (str): path to store the dataset, e.g.: '/home/foo/datasets/data' dataset_metadata (Dict, optional): additional metadata to add to HDF5 dataset file as attributes. Defaults to {}. """ self._validate_buffer() @@ -269,13 +268,18 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): episode_metadata = self._storage.apply(self._episode_metadata_callback) self._storage.update_episode_metadata(episode_metadata) - shutil.move(str(self._storage.data_path), path) + files = os.listdir(self._storage.data_path) + for file in files: + shutil.move( + os.path.join(self._storage.data_path, file), + os.path.join(path, file), + ) # Reset episode count self._episode_id = 0 def close(self): - """Close the Gymnasium environment. + """Close the DataCollector. Clear buffer and close temporary directory. 
""" diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 9594f7a0..8d641a35 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -89,11 +89,7 @@ def __init__( """ if isinstance(data, MinariStorage): self._data = data - elif ( - isinstance(data, str) - or isinstance(data, os.PathLike) - or isinstance(data, bytes) - ): + elif isinstance(data, PathLike): self._data = MinariStorage(data) else: raise ValueError(f"Unrecognized type {type(data)} for data") @@ -314,6 +310,7 @@ def update_dataset_from_buffer(self, buffer: List[dict]): self.spec.total_episodes = self._episode_indices.size + # TODO: avoid this self.spec.total_steps = sum( self._data.apply( lambda episode: episode["total_timesteps"], diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 9cd98c4e..607bac29 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -1,4 +1,5 @@ import os +import pathlib from collections import OrderedDict from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -6,35 +7,44 @@ import h5py import numpy as np -from minari.serialization import deserialize_space +from minari.serialization import deserialize_space, serialize_space -PathLike = Union[str, bytes, os.PathLike] +PathLike = Union[str, os.PathLike] class MinariStorage: def __init__(self, data_path: PathLike): - base_path, ext = os.path.splitext(str(data_path)) - if ext != "" and ext != ".hdf5": - raise ValueError(f"Only hdf5 extension is supported, found {ext}") - self._data_path = base_path + ".hdf5" - + if not os.path.exists(data_path) or not os.path.isdir(data_path): + raise ValueError(f"The data path {data_path} doesn't exists") + file_path = os.path.join(str(data_path), "main_data.hdf5") + if not os.path.exists(file_path): + raise ValueError(f"No data found in data path {data_path}") + self._file_path = file_path + @classmethod - def new(cls, data_path: PathLike, action_space, observation_space, env_spec): + def new(cls, data_path: PathLike, action_space, observation_space, env_spec=None): + data_path = pathlib.Path(data_path) + data_path.mkdir(exist_ok=True) + data_path.joinpath("main_data.hdf5").touch(exist_ok=False) + obj = cls(data_path) - obj.update_metadata({ - "action_space": action_space, - "observation_space": observation_space, - "env_spec": env_spec, + metadata = { + "action_space": serialize_space(action_space), + "observation_space": serialize_space(observation_space), "total_episodes": 0, "total_steps": 0 - }) + } + if env_spec is not None: + metadata["env_spec"] = env_spec.to_json() + + obj.update_metadata(metadata) return obj @property def metadata(self) -> Dict: metadata = {} - with h5py.File(self.data_path, "r") as file: + with h5py.File(self._file_path, "r") as file: metadata.update(file.attrs) if "observation_space" in metadata.keys(): space_serialization = metadata["observation_space"] @@ -48,7 +58,7 @@ def metadata(self) -> Dict: return metadata def update_metadata(self, metadata: Dict): - with h5py.File(self.data_path, "a") as file: + with h5py.File(self._file_path, "a") as file: file.attrs.update(metadata) def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Optional[Iterable] = None): @@ -57,7 +67,7 @@ def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Option if len(metadatas) != len(list(episode_indices)): raise ValueError("The number of metadatas doesn't match the number of episodes in the dataset.") - with 
h5py.File(self.data_path, "a") as file: + with h5py.File(self._file_path, "a") as file: for metadata, episode_id in zip(metadatas, episode_indices): ep_group = file[f"episode_{episode_id}"] ep_group.attrs.update(metadata) @@ -79,7 +89,7 @@ def apply( if episode_indices is None: episode_indices = range(self.total_episodes) out = [] - with h5py.File(self._data_path, "r") as file: + with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] assert isinstance(ep_group, h5py.Group) @@ -87,7 +97,7 @@ def apply( "id": ep_group.attrs.get("id"), "total_timesteps": ep_group.attrs.get("total_steps"), "seed": ep_group.attrs.get("seed"), - # TODO: self.metadata can be slow for decode space? Cache spaces? Cache metadata (bad for consistency)? + # TODO: self.metadata can be slow for decode space? Cache spaces? Cache metadata? "observations": self._decode_space( ep_group["observations"], self.metadata["observation_space"] ), @@ -139,7 +149,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: episodes (List[dict]): list of episodes data """ out = [] - with h5py.File(self._data_path, "r") as file: + with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] out.append( @@ -161,11 +171,20 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: return out - def update_episodes(self, episodes: List[dict]): + def update_episodes(self, episodes: Iterable[dict]): + """Update epsiodes in the storage from a list of episode buffer. + + Args: + episodes (Iterable[dict]): list of episodes buffer. + They must contain the keys specified in EpsiodeData dataclass, except for `id` which is optional. + If `id` is specified and exists, the new data is appended to the one in the storage. + """ additional_steps = 0 - with h5py.File(self.data_path, "a", track_order=True) as file: + with h5py.File(self._file_path, "a", track_order=True) as file: for eps_buff in episodes: - episode_id = eps_buff.pop("id") + total_episodes = len(file.keys()) + episode_id = eps_buff.pop("id", total_episodes) + assert episode_id <= total_episodes, "Invalid episode id; ids must be sequential." 
episode_group = get_h5py_subgroup(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id if "seed" in eps_buff.keys(): @@ -187,12 +206,12 @@ def update_episodes(self, episodes: List[dict]): @property def data_path(self) -> PathLike: """Full path to the `main_data.hdf5` file of the dataset.""" - return self._data_path + return os.path.dirname(self._file_path) @property def total_episodes(self) -> np.int64: """Total episodes in the dataset.""" - with h5py.File(self.data_path, "r") as file: + with h5py.File(self._file_path, "r") as file: total_episodes = file.attrs["total_episodes"] assert type(total_episodes) == np.int64 return total_episodes @@ -200,7 +219,7 @@ def total_episodes(self) -> np.int64: @property def total_steps(self) -> np.int64: """Total steps in the dataset.""" - with h5py.File(self.data_path, "r") as file: + with h5py.File(self._file_path, "r") as file: total_episodes = file.attrs["total_steps"] assert type(total_episodes) == np.int64 return total_episodes @@ -218,6 +237,8 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): if isinstance(data, dict): episode_group_to_clear = get_h5py_subgroup(episode_group, key) _add_episode_to_group(data, episode_group_to_clear) + elif isinstance(data, int): + import pdb; pdb.set_trace() elif all([isinstance(entry, tuple) for entry in data]): # we have a list of tuples, so we need to act appropriately dict_data = { diff --git a/minari/storage/datasets_root_dir.py b/minari/storage/datasets_root_dir.py index 35a412a1..83e27318 100644 --- a/minari/storage/datasets_root_dir.py +++ b/minari/storage/datasets_root_dir.py @@ -2,7 +2,7 @@ from pathlib import Path -def get_dataset_path(dataset_id): +def get_dataset_path(dataset_id: str) -> Path: """Get the path to a dataset main directory.""" datasets_path = os.environ.get("MINARI_DATASETS_PATH") if datasets_path is not None: diff --git a/minari/storage/local.py b/minari/storage/local.py index 0403df38..3bcea6d0 100644 --- a/minari/storage/local.py +++ b/minari/storage/local.py @@ -26,7 +26,7 @@ def load_dataset(dataset_id: str, download: bool = False): MinariDataset """ file_path = get_dataset_path(dataset_id) - data_path = os.path.join(file_path, "data", "main_data.hdf5") + data_path = os.path.join(file_path, "data") if not os.path.exists(data_path): if not download: diff --git a/minari/utils.py b/minari/utils.py index 7a0048da..ea9bfeaf 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -18,6 +18,7 @@ from minari import DataCollectorV0 from minari.dataset.minari_dataset import MinariDataset +from minari.dataset.minari_storage import MinariStorage # from minari.dataset.minari_storage import clear_episode_buffer from minari.serialization import serialize_space from minari.storage.datasets_root_dir import get_dataset_path @@ -203,6 +204,7 @@ def __call__(self, observation: ObsType) -> ActType: return self.action_space.sample() +# TODO: factor h5py out def combine_datasets( datasets_to_combine: List[MinariDataset], new_dataset_id: str, copy: bool = False ): @@ -236,13 +238,13 @@ def combine_datasets( if not os.path.exists(new_dataset_path): new_dataset_path = os.path.join(new_dataset_path, "data") os.makedirs(new_dataset_path) - new_data_path = os.path.join(new_dataset_path, "main_data.hdf5") + new_file_path = os.path.join(new_dataset_path, "main_data.hdf5") else: raise ValueError( f"A Minari dataset with ID {new_dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." 
) - with h5py.File(new_data_path, "a", track_order=True) as combined_data_file: + with h5py.File(new_file_path, "a", track_order=True) as combined_data_file: combined_data_file.attrs["total_episodes"] = 0 combined_data_file.attrs["total_steps"] = 0 combined_data_file.attrs["dataset_id"] = new_dataset_id @@ -253,8 +255,9 @@ def combine_datasets( for dataset in datasets_to_combine: last_episode_id = combined_data_file.attrs["total_episodes"] + file_path = f"{dataset.spec.data_path}/main_data.hdf5" if copy: - with h5py.File(dataset.spec.data_path, "r") as dataset_file: + with h5py.File(file_path, "r") as dataset_file: for id in range(dataset.total_episodes): dataset_file.copy( dataset_file[f"episode_{id}"], @@ -268,7 +271,7 @@ def combine_datasets( for id in range(dataset.total_episodes): combined_data_file[ f"episode_{last_episode_id + id}" - ] = h5py.ExternalLink(dataset.spec.data_path, f"/episode_{id}") + ] = h5py.ExternalLink(file_path, f"/episode_{id}") combined_data_file[f"episode_{last_episode_id + id}"].attrs.modify( "id", last_episode_id + id ) @@ -283,7 +286,7 @@ def combine_datasets( ) # TODO: list of authors, and emails - with h5py.File(dataset.spec.data_path, "r") as dataset_file: + with h5py.File(file_path, "r") as dataset_file: combined_data_file.attrs.modify("author", dataset_file.attrs["author"]) combined_data_file.attrs.modify( "author_email", dataset_file.attrs["author_email"] @@ -293,7 +296,7 @@ def combine_datasets( combined_data_file.attrs["env_spec"] = combined_dataset_env_spec.to_json() combined_data_file.attrs["minari_version"] = str(minari_version_specifier) - return MinariDataset(new_data_path) + return MinariDataset(new_dataset_path) def split_dataset( @@ -351,171 +354,141 @@ def get_average_reference_score( return float(mean_ref_score) -# def create_dataset_from_buffers( -# dataset_id: str, -# env: gym.Env, -# buffer: List[Dict[str, Union[list, Dict]]], -# algorithm_name: Optional[str] = None, -# author: Optional[str] = None, -# author_email: Optional[str] = None, -# code_permalink: Optional[str] = None, -# minari_version: Optional[str] = None, -# action_space: Optional[gym.spaces.Space] = None, -# observation_space: Optional[gym.spaces.Space] = None, -# ref_min_score: Optional[float] = None, -# ref_max_score: Optional[float] = None, -# expert_policy: Optional[Callable[[ObsType], ActType]] = None, -# num_episodes_average_score: int = 100, -# ): -# """Create Minari dataset from a list of episode dictionary buffers. - -# The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows: -# ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. -# This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. - -# Each episode dictionary buffer must have the following items: -# * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation -# * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). -# * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). -# * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). -# * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). - -# Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. 
- -# Args: -# dataset_id (str): name id to identify Minari dataset -# env (gym.Env): Gymnasium environment used to collect the buffer data -# buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data -# algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None. -# author (Optional[str], optional): author that generated the dataset. Defaults to None. -# author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None. -# code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None. -# ref_min_score (Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy. -# Also note that this attribute will be added to the Minari dataset only if `ref_max_score` or `expert_policy` are assigned a valid value other than None. -# ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in `MinariDataset.get_normalized_score()`. Default None. -# expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to `num_episodes_average_score`. -# `ref_max_score` and `expert_policy` can't be passed at the same time. Default to None -# num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100. - -# Returns: -# MinariDataset -# """ -# # NoneType warnings -# if code_permalink is None: -# warnings.warn( -# "`code_permalink` is set to None. For reproducibility purposes it is highly recommended to link your dataset to versioned code.", -# UserWarning, -# ) -# if author is None: -# warnings.warn( -# "`author` is set to None. For longevity purposes it is highly recommended to provide an author name.", -# UserWarning, -# ) -# if author_email is None: -# warnings.warn( -# "`author_email` is set to None. For longevity purposes it is highly recommended to provide an author email, or some other obvious contact information.", -# UserWarning, -# ) -# if minari_version is None: -# version = Version(__version__) -# release = version.release -# # For __version__ = X.Y.Z, set version specifier by default to compatibility with version X.Y or later, but not (X+1).0 or later. -# minari_version = f"~={'.'.join(str(x) for x in release[:2])}" -# warnings.warn( -# f"`minari_version` is set to None. The compatible dataset version specifier for Minari will be set to {minari_version}.", -# UserWarning, -# ) -# # Check if the installed Minari version falls inside the minari_version specifier -# try: -# assert Version(__version__) in SpecifierSet( -# minari_version -# ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." -# except InvalidSpecifier: -# print(f"{minari_version} is not a version specifier.") - -# if observation_space is None: -# observation_space = env.observation_space -# if action_space is None: -# action_space = env.action_space - -# if expert_policy is not None and ref_max_score is not None: -# raise ValueError( -# "Can't pass a value for `expert_policy` and `ref_max_score` at the same time." 
-# ) - -# dataset_path = get_dataset_path(dataset_id) - -# # Check if dataset already exists -# if not os.path.exists(dataset_path): -# dataset_path = os.path.join(dataset_path, "data") -# os.makedirs(dataset_path) -# data_path = os.path.join(dataset_path, "main_data.hdf5") - -# total_steps = 0 -# with h5py.File(data_path, "w", track_order=True) as file: -# for i, eps_buff in enumerate(buffer): -# # check episode terminated or truncated -# assert ( -# eps_buff["terminations"][-1] or eps_buff["truncations"][-1] -# ), "Each episode must be terminated or truncated before adding it to a Minari dataset" -# assert len(eps_buff["actions"]) + 1 == len( -# eps_buff["observations"] -# ), f"Number of observations {len(eps_buff['observations'])} must have an additional element compared to the number of action steps {len(eps_buff['actions'])}. The initial and final observation must be included" -# seed = eps_buff.pop("seed", None) -# eps_group = clear_episode_buffer( -# eps_buff, file.create_group(f"episode_{i}") -# ) - -# eps_group.attrs["id"] = i -# episode_total_steps = len(eps_buff["actions"]) -# eps_group.attrs["total_steps"] = episode_total_steps -# total_steps += episode_total_steps - -# if seed is None: -# eps_group.attrs["seed"] = str(None) -# else: -# assert isinstance(seed, int) -# eps_group.attrs["seed"] = seed - -# # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata - -# file.attrs["total_episodes"] = len(buffer) -# file.attrs["total_steps"] = total_steps - -# file.attrs[ -# "env_spec" -# ] = env.spec.to_json() # pyright: ignore [reportOptionalMemberAccess] -# file.attrs["dataset_id"] = dataset_id - -# action_space_str = serialize_space(action_space) -# observation_space_str = serialize_space(observation_space) - -# file.attrs["action_space"] = action_space_str -# file.attrs["observation_space"] = observation_space_str - -# if expert_policy is not None or ref_max_score is not None: -# env = copy.deepcopy(env) -# if ref_min_score is None: -# ref_min_score = get_average_reference_score( -# env, RandomPolicy(env), num_episodes_average_score -# ) - -# if expert_policy is not None: -# ref_max_score = get_average_reference_score( -# env, expert_policy, num_episodes_average_score -# ) - -# file.attrs["ref_max_score"] = ref_max_score -# file.attrs["ref_min_score"] = ref_min_score -# file.attrs["num_episodes_average_score"] = num_episodes_average_score - -# file.attrs["minari_version"] = minari_version - -# return MinariDataset(data_path) -# else: -# raise ValueError( -# f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." -# ) +def create_dataset_from_buffers( + dataset_id: str, + env: gym.Env, + buffer: List[Dict[str, Union[list, Dict]]], + algorithm_name: Optional[str] = None, + author: Optional[str] = None, + author_email: Optional[str] = None, + code_permalink: Optional[str] = None, + minari_version: Optional[str] = None, + action_space: Optional[gym.spaces.Space] = None, + observation_space: Optional[gym.spaces.Space] = None, + ref_min_score: Optional[float] = None, + ref_max_score: Optional[float] = None, + expert_policy: Optional[Callable[[ObsType], ActType]] = None, + num_episodes_average_score: int = 100, +): + """Create Minari dataset from a list of episode dictionary buffers. 
+ + The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows: + ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. + This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. + + Each episode dictionary buffer must have the following items: + * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation + * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). + * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). + * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). + * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). + + Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. + + Args: + dataset_id (str): name id to identify Minari dataset + env (gym.Env): Gymnasium environment used to collect the buffer data + buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data + algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None. + author (Optional[str], optional): author that generated the dataset. Defaults to None. + author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None. + code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None. + ref_min_score (Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy. + Also note that this attribute will be added to the Minari dataset only if `ref_max_score` or `expert_policy` are assigned a valid value other than None. + ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in `MinariDataset.get_normalized_score()`. Default None. + expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to `num_episodes_average_score`. + `ref_max_score` and `expert_policy` can't be passed at the same time. Default to None + num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100. + + Returns: + MinariDataset + """ + # NoneType warnings + if code_permalink is None: + warnings.warn( + "`code_permalink` is set to None. For reproducibility purposes it is highly recommended to link your dataset to versioned code.", + UserWarning, + ) + if author is None: + warnings.warn( + "`author` is set to None. For longevity purposes it is highly recommended to provide an author name.", + UserWarning, + ) + if author_email is None: + warnings.warn( + "`author_email` is set to None. 
For longevity purposes it is highly recommended to provide an author email, or some other obvious contact information.", + UserWarning, + ) + if minari_version is None: + version = Version(__version__) + release = version.release + # For __version__ = X.Y.Z, set version specifier by default to compatibility with version X.Y or later, but not (X+1).0 or later. + minari_version = f"~={'.'.join(str(x) for x in release[:2])}" + warnings.warn( + f"`minari_version` is set to None. The compatible dataset version specifier for Minari will be set to {minari_version}.", + UserWarning, + ) + # Check if the installed Minari version falls inside the minari_version specifier + try: + assert Version(__version__) in SpecifierSet( + minari_version + ), f"The installed Minari version {__version__} is not contained in the dataset version specifier {minari_version}." + except InvalidSpecifier: + print(f"{minari_version} is not a version specifier.") + + if observation_space is None: + observation_space = env.observation_space + if action_space is None: + action_space = env.action_space + + if expert_policy is not None and ref_max_score is not None: + raise ValueError( + "Can't pass a value for `expert_policy` and `ref_max_score` at the same time." + ) + + dataset_path = get_dataset_path(dataset_id) + + # Check if dataset already exists + if os.path.exists(dataset_path): + raise ValueError( + f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." + ) + dataset_path.mkdir() + + dataset_path = os.path.join(dataset_path, "data") + storage = MinariStorage.new(dataset_path, action_space, observation_space, env_spec=env.spec) + + metadata: Dict[str, Any] = { + "dataset_id": dataset_id, + "minari_version": minari_version + } + if algorithm_name is not None: + metadata["algorithm_name"] = algorithm_name + if author is not None: + metadata["author"] = author + if author_email is not None: + metadata["author_email"] = author_email + if code_permalink is not None: + metadata["code_permalink"] = code_permalink + if expert_policy is not None or ref_max_score is not None: + env = copy.deepcopy(env) + if ref_min_score is None: + ref_min_score = get_average_reference_score( + env, RandomPolicy(env), num_episodes_average_score + ) + + if expert_policy is not None: + ref_max_score = get_average_reference_score( + env, expert_policy, num_episodes_average_score + ) + + metadata["ref_max_score"] = ref_max_score + metadata["ref_min_score"] = ref_min_score + metadata["num_episodes_average_score"] = num_episodes_average_score + + storage.update_metadata(metadata) + storage.update_episodes(buffer) + return MinariDataset(storage) def create_dataset_from_collector_env( @@ -596,54 +569,42 @@ def create_dataset_from_collector_env( dataset_path = os.path.join(collector_env.datasets_path, dataset_id) # Check if dataset already exists - if not os.path.exists(dataset_path): - dataset_path = os.path.join(dataset_path, "data") - os.makedirs(dataset_path) - data_path = os.path.join(dataset_path, "main_data.hdf5") - dataset_metadata: Dict[str, Any] = { - "dataset_id": str(dataset_id), - "algorithm_name": str(algorithm_name), - "author": str(author), - "author_email": str(author_email), - "code_permalink": str(code_permalink), - } - - if expert_policy is not None or ref_max_score is not None: - env = copy.deepcopy(collector_env.env) - if ref_min_score is None: - ref_min_score = get_average_reference_score( - env, RandomPolicy(env), num_episodes_average_score - ) - - if 
expert_policy is not None: - ref_max_score = get_average_reference_score( - env, expert_policy, num_episodes_average_score - ) - dataset_metadata.update( - { - "ref_max_score": ref_max_score, - "ref_min_score": ref_min_score, - "num_episodes_average_score": num_episodes_average_score, - } - ) - - collector_env.save_to_disk( - data_path, - dataset_metadata={ - "dataset_id": str(dataset_id), - "algorithm_name": str(algorithm_name), - "author": str(author), - "author_email": str(author_email), - "code_permalink": str(code_permalink), - "minari_version": minari_version, - }, - ) - return MinariDataset(data_path) - else: + if os.path.exists(dataset_path): raise ValueError( f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." ) + + dataset_path = os.path.join(dataset_path, "data") + os.makedirs(dataset_path) + dataset_metadata: Dict[str, Any] = { + "dataset_id": dataset_id, + "minari_version": minari_version, + } + if algorithm_name is not None: + dataset_metadata["algorithm_name"] = algorithm_name + if author is not None: + dataset_metadata["author"] = author + if author_email is not None: + dataset_metadata["author_email"] = author_email + if code_permalink is not None: + dataset_metadata["code_permalink"] = code_permalink + if expert_policy is not None or ref_max_score is not None: + env = copy.deepcopy(collector_env.env) + if ref_min_score is None: + ref_min_score = get_average_reference_score( + env, RandomPolicy(env), num_episodes_average_score + ) + + if expert_policy is not None: + ref_max_score = get_average_reference_score( + env, expert_policy, num_episodes_average_score + ) + dataset_metadata["ref_max_score"] = ref_max_score + dataset_metadata["ref_min_score"] = ref_min_score + dataset_metadata["num_episodes_average_score"] = num_episodes_average_score + collector_env.save_to_disk(dataset_path, dataset_metadata) + return MinariDataset(dataset_path) def get_normalized_score( dataset: MinariDataset, returns: Union[float, np.float32] diff --git a/tests/common.py b/tests/common.py index 2f95637b..e4609fe3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -566,10 +566,6 @@ def create_dummy_dataset_with_collecter_env_helper( while not terminated and not truncated: action = env.action_space.sample() # User-defined policy function _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] env.reset() @@ -583,6 +579,8 @@ def create_dummy_dataset_with_collecter_env_helper( author_email="wdudley@farama.org", ) env.close() + + assert dataset_id in minari.list_local_datasets() return dataset def check_episode_data_integrity( diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index 63907a0f..570351c4 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -139,7 +139,6 @@ def test_truncation_without_reset(dataset_id, env_id): else: assert np.array_equal(first_step.observations, last_step.observations) last_step = get_single_step_from_episode(episode, -1) - print(last_step.truncations) assert bool(last_step.truncations) is True # check load and delete local dataset diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 5a2be869..660697f2 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -1,4 +1,5 @@ import copy +import os import re from typing 
import Any @@ -7,8 +8,10 @@ import pytest import minari +from minari import __version__ from minari import DataCollectorV0, MinariDataset from minari.dataset.minari_dataset import EpisodeData +from minari.dataset.minari_storage import MinariStorage from tests.common import ( check_data_integrity, check_env_recovery, @@ -90,10 +93,6 @@ def test_update_dataset_from_collector_env(dataset_id, env_id): while not terminated and not truncated: action = env.action_space.sample() # User-defined policy function _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] env.reset() @@ -164,10 +163,6 @@ def filter_by_index(episode: Any): while not terminated and not truncated: action = env.action_space.sample() # User-defined policy function _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] env.reset() @@ -196,7 +191,6 @@ def filter_by_index(episode: Any): 18, 19, ) - print(dataset._episode_indices) assert filtered_dataset._data.total_episodes == 20 assert dataset._data.total_episodes == 20 check_env_recovery(env.env, filtered_dataset) @@ -285,7 +279,6 @@ def filter_by_index(episode: Any): 28, 29, ) - print(dataset._episode_indices) assert filtered_dataset._data.total_episodes == 30 assert dataset._data.total_episodes == 30 check_env_recovery(env, filtered_dataset) @@ -333,8 +326,6 @@ def filter_by_index(episode: Any): with pytest.raises(ValueError): episodes = filtered_dataset.sample_episodes(8) - env.close() - @pytest.mark.parametrize( "dataset_id,env_id", @@ -382,8 +373,6 @@ def test_iterate_episodes(dataset_id, env_id): assert length == 10 assert len(dataset) == 10 - env.close() - @pytest.mark.parametrize( "dataset_id,env_id", @@ -466,6 +455,30 @@ def test_update_dataset_from_buffer(dataset_id, env_id): check_data_integrity(dataset._data, dataset.episode_indices) check_env_recovery(env, dataset) - collector_env.close() - + env.close() check_load_and_delete_dataset(dataset_id) + + +def test_missing_env_module(): + data_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets", "dummy-test-v0") + storage = MinariStorage.new( + data_path, + observation_space=gym.spaces.Box(-1, 1), + action_space=gym.spaces.Box(-1, 1), + ) + storage.update_metadata({ + "flatten_observation": False, + "flatten_action": False, + "env_spec": r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""", + "total_episodes": 100, + "total_steps": 1000, + "dataset_id": "dummy-test-v0", + "minari_version": f"=={__version__}" + }) + + with pytest.raises( + ModuleNotFoundError, match="Install dummymodule for loading DummyEnv-v0 data" + ): + MinariDataset(storage.data_path) + + os.remove(data_path) \ No newline at end of file diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 12be73a1..f87f5c14 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,41 +1 @@ -import os - -import h5py -import pytest - -from minari import __version__ -from minari.dataset.minari_storage import MinariStorage - - -file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets") - - -def _create_dummy_dataset(file_path): - - os.makedirs(file_path, exist_ok=True) - - with 
h5py.File(os.path.join(file_path, "dummy-test-v0.hdf5"), "w") as f: - - f.attrs["flatten_observation"] = False - f.attrs["flatten_action"] = False - f.attrs[ - "env_spec" - ] = r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""" - f.attrs["total_episodes"] = 100 - f.attrs["total_steps"] = 1000 - f.attrs["dataset_id"] = "dummy-test-v0" - f.attrs["minari_version"] = f"=={__version__}" - - -def test_minari_storage_missing_env_module(): - - file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets") - - _create_dummy_dataset(file_path) - - with pytest.raises( - ModuleNotFoundError, match="Install dummymodule for loading DummyEnv-v0 data" - ): - MinariStorage(os.path.join(file_path, "dummy-test-v0.hdf5")) - - os.remove(os.path.join(file_path, "dummy-test-v0.hdf5")) +# TODO \ No newline at end of file From 2f77f92fcf8bbef2cc6f178b189160902173627a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 9 Sep 2023 19:19:23 -0400 Subject: [PATCH 06/19] remove h5py dependency in local --- minari/storage/local.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/minari/storage/local.py b/minari/storage/local.py index 3bcea6d0..54213ac3 100644 --- a/minari/storage/local.py +++ b/minari/storage/local.py @@ -3,10 +3,10 @@ import shutil from typing import Dict, Union -import h5py from packaging.specifiers import SpecifierSet from minari.dataset.minari_dataset import MinariDataset, parse_dataset_id +from minari.dataset.minari_storage import MinariStorage from minari.storage import hosting from minari.storage.datasets_root_dir import get_dataset_path @@ -67,25 +67,24 @@ def list_local_datasets( # Minari datasets must contain the data directory. 
continue - # TODO: remove hdf5 references - main_file_path = os.path.join(datasets_path, dst_id, "data/main_data.hdf5") - with h5py.File(main_file_path, "r") as f: - metadata = dict(f.attrs.items()) - if ("minari_version" not in metadata) or ( - compatible_minari_version - and __version__ not in SpecifierSet(metadata["minari_version"]) + data_path = os.path.join(datasets_path, dst_id, "data") + metadata = MinariStorage(data_path).metadata + if ("minari_version" not in metadata) or ( + compatible_minari_version + and __version__ not in SpecifierSet(metadata["minari_version"]) + ): + continue + env_name, dataset_name, version = parse_dataset_id(dst_id) + dataset = f"{env_name}-{dataset_name}" + if latest_version: + if ( + dataset not in local_datasets + or version > local_datasets[dataset][0] ): - continue - env_name, dataset_name, version = parse_dataset_id(dst_id) - dataset = f"{env_name}-{dataset_name}" - if latest_version: - if ( - dataset not in local_datasets - or version > local_datasets[dataset][0] - ): - local_datasets[dataset] = (version, metadata) - else: - local_datasets[dst_id] = metadata + local_datasets[dataset] = (version, metadata) + else: + local_datasets[dst_id] = metadata + if latest_version: # Return dict = {'dataset_id': metadata} return dict( From e936c7f68c1f188ad74af8886c7c3ae138d12d9d Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 9 Sep 2023 19:19:32 -0400 Subject: [PATCH 07/19] general refactoring --- .../callbacks/episode_metadata.py | 4 +- minari/data_collector/data_collector.py | 2 +- minari/dataset/minari_dataset.py | 3 +- minari/dataset/minari_storage.py | 71 +++++-- minari/utils.py | 9 +- tests/dataset/test_minari_storage.py | 187 +++++++++++++++++- 6 files changed, 255 insertions(+), 21 deletions(-) diff --git a/minari/data_collector/callbacks/episode_metadata.py b/minari/data_collector/callbacks/episode_metadata.py index 5779cdd2..e4ed01cf 100644 --- a/minari/data_collector/callbacks/episode_metadata.py +++ b/minari/data_collector/callbacks/episode_metadata.py @@ -6,7 +6,7 @@ class EpisodeMetadataCallback: """Callback to full episode after saving to hdf5 file as a group. This callback can be overridden to add extra metadata attributes or statistics to - each HDF5 episode group in the Minari dataset. The custom callback can then be + each episode in the Minari dataset. The custom callback can then be passed to the DataCollectorV0 wrapper to the `episode_metadata_callback` argument. TODO: add more default statistics to episode datasets @@ -18,7 +18,7 @@ def __call__(self, episode: Dict): Override this method to add custom attribute metadata to the episode group. 
Args: - eps_group (h5py.Group): the HDF5 group that contains an episode's datasets + eps_group (dict): the dict that contains an episode's data """ return { "rewards_sum": np.sum(episode["rewards"]), diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 95726dcc..3315bfbb 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -119,8 +119,8 @@ def __init__( assert self.env.spec is not None, "Env Spec is None" self._storage = MinariStorage.new( self._tmp_dir.name, - action_space=self.dataset_action_space, observation_space=self.dataset_observation_space, + action_space=self.dataset_action_space, env_spec=self.env.spec ) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 8d641a35..05f876fa 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -2,7 +2,6 @@ import importlib.metadata import json -import os import re from dataclasses import dataclass, field from typing import Callable, Iterable, Iterator, List, Optional, Union @@ -84,7 +83,7 @@ def __init__( """Initialize properties of the Minari Dataset. Args: - data (Union[MinariStorage, _PathLike]): source of data. + data (Union[MinariStorage, PathLike]): source of data. episode_indices (Optiona[np.ndarray]): slice of episode indices this dataset is pointing to. """ if isinstance(data, MinariStorage): diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 607bac29..8927d007 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import os import pathlib from collections import OrderedDict from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import gymnasium as gym +from gymnasium.envs.registration import EnvSpec import h5py import numpy as np @@ -14,24 +17,53 @@ class MinariStorage: + """Class that handles disk access to the data.""" + def __init__(self, data_path: PathLike): + """Initialize a MinariStorage with an existing data path. + To create a new dataset, use the class method `new`. + + Args: + data_path (str or Path): directory containing the data. + + Raises: + ValueError: if the specified path doesn't exist or doesn't contain any data. + + """ if not os.path.exists(data_path) or not os.path.isdir(data_path): - raise ValueError(f"The data path {data_path} doesn't exists") + raise ValueError(f"The data path {data_path} doesn't exist") file_path = os.path.join(str(data_path), "main_data.hdf5") if not os.path.exists(file_path): raise ValueError(f"No data found in data path {data_path}") self._file_path = file_path @classmethod - def new(cls, data_path: PathLike, action_space, observation_space, env_spec=None): + def new( + cls, + data_path: PathLike, + observation_space: gym.Space, + action_space: gym.Space, + env_spec: Optional[EnvSpec] = None + ) -> MinariStorage: + """Class method to create a new data storage. + + Args: + data_path (str or Path): directory where the data will be stored. + observation_space (gymnasium.Space): Gymnasium observation space of the dataset. + action_space (gymnasium.Space): Gymnasium action space of the dataset. + env_spec (EnvSpec): Gymnasium EnvSpec of the environment that generates the dataset. + + Returns: + A new MinariStorage object. 
+ """ data_path = pathlib.Path(data_path) data_path.mkdir(exist_ok=True) data_path.joinpath("main_data.hdf5").touch(exist_ok=False) obj = cls(data_path) metadata = { - "action_space": serialize_space(action_space), "observation_space": serialize_space(observation_space), + "action_space": serialize_space(action_space), "total_episodes": 0, "total_steps": 0 } @@ -43,6 +75,7 @@ def new(cls, data_path: PathLike, action_space, observation_space, env_spec=None @property def metadata(self) -> Dict: + """Metadata of the dataset.""" metadata = {} with h5py.File(self._file_path, "r") as file: metadata.update(file.attrs) @@ -58,14 +91,29 @@ def metadata(self) -> Dict: return metadata def update_metadata(self, metadata: Dict): + """Update the metadata adding/modifying some keys. + + Args: + metadata (dict): dictionary of keys-values to add to the metadata. + """ with h5py.File(self._file_path, "a") as file: file.attrs.update(metadata) def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Optional[Iterable] = None): + """Update the metadata of episodes. + + Args: + metadatas (List[Dict]): list of metadatas, one for each episode. + episode_indices (Iterable, optional): list of episode indices to update. + If not specified, all the episodes are considered. + + Raises: + ValueError: if the lengths of metadatas and episodes to update don't match. + """ if episode_indices is None: episode_indices = range(self.total_episodes) if len(metadatas) != len(list(episode_indices)): - raise ValueError("The number of metadatas doesn't match the number of episodes in the dataset.") + raise ValueError("The number of metadatas doesn't match the number of episodes to update.") with h5py.File(self._file_path, "a") as file: for metadata, episode_id in zip(metadatas, episode_indices): @@ -81,7 +129,7 @@ def apply( Args: function (Callable): function to apply to episodes - episode_indices (Optional[Iterable]): epsiodes id to consider + episode_indices (Optional[Iterable]): episodes id to consider Returns: outs (list): list of outputs returned by the function applied to episodes @@ -185,7 +233,7 @@ def update_episodes(self, episodes: Iterable[dict]): total_episodes = len(file.keys()) episode_id = eps_buff.pop("id", total_episodes) assert episode_id <= total_episodes, "Invalid episode id; ids must be sequential." 
- episode_group = get_h5py_subgroup(file, f"episode_{episode_id}") + episode_group = _get_from_h5py(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id if "seed" in eps_buff.keys(): assert not "seed" in episode_group.attrs.keys() @@ -224,9 +272,10 @@ def total_steps(self) -> np.int64: assert type(total_episodes) == np.int64 return total_episodes -def get_h5py_subgroup(group: h5py.Group, name: str): +def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: subgroup = group.get(name) + assert isinstance(subgroup, h5py.Group) else: subgroup = group.create_group(name) @@ -235,24 +284,22 @@ def get_h5py_subgroup(group: h5py.Group, name: str): def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): for key, data in episode_buffer.items(): if isinstance(data, dict): - episode_group_to_clear = get_h5py_subgroup(episode_group, key) + episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(data, episode_group_to_clear) - elif isinstance(data, int): - import pdb; pdb.set_trace() elif all([isinstance(entry, tuple) for entry in data]): # we have a list of tuples, so we need to act appropriately dict_data = { f"_index_{str(i)}": [entry[i] for entry in data] for i, _ in enumerate(data[0]) } - episode_group_to_clear = get_h5py_subgroup(episode_group, key) + episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) elif all([isinstance(entry, OrderedDict) for entry in data]): # we have a list of OrderedDicts, so we need to act appropriately dict_data = { key: [entry[key] for entry in data] for key, value in data[0].items() } - episode_group_to_clear = get_h5py_subgroup(episode_group, key) + episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) else: # leaf data if isinstance(episode_group, h5py.Dataset): diff --git a/minari/utils.py b/minari/utils.py index ea9bfeaf..6c465273 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -19,8 +19,6 @@ from minari import DataCollectorV0 from minari.dataset.minari_dataset import MinariDataset from minari.dataset.minari_storage import MinariStorage -# from minari.dataset.minari_storage import clear_episode_buffer -from minari.serialization import serialize_space from minari.storage.datasets_root_dir import get_dataset_path @@ -456,7 +454,12 @@ def create_dataset_from_buffers( dataset_path.mkdir() dataset_path = os.path.join(dataset_path, "data") - storage = MinariStorage.new(dataset_path, action_space, observation_space, env_spec=env.spec) + storage = MinariStorage.new( + dataset_path, + observation_space=observation_space, + action_space=action_space, + env_spec=env.spec + ) metadata: Dict[str, Any] = { "dataset_id": dataset_id, diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index f87f5c14..b49068e6 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1 +1,186 @@ -# TODO \ No newline at end of file +import tempfile +from minari.dataset.minari_storage import MinariStorage +from gymnasium import spaces +import pytest +import numpy as np + + +@pytest.fixture(autouse=True) +def tmp_dir(): + tmp_dir = tempfile.TemporaryDirectory() + yield tmp_dir.name + tmp_dir.cleanup() + + +def _generate_episode_dict(observation_space: spaces.Space, action_space: spaces.Space, length=25): + terminations = np.zeros(length, dtype=np.bool_) + truncations = np.zeros(length, dtype=np.bool_) + 
terminated = np.random.randint(2, dtype=np.bool_) + terminations[-1] = terminated + truncations[-1] = not terminated + + return { + "observations": [observation_space.sample() for _ in range(length + 1)], + "actions": [action_space.sample() for _ in range(length)], + "rewards": np.random.randn(length), + "terminations": terminations, + "truncations": truncations + } + +def test_non_existing_data(tmp_dir): + with pytest.raises(ValueError, match="The data path foo doesn't exist"): + MinariStorage("foo") + + with pytest.raises(ValueError, match="No data found in data path"): + MinariStorage(tmp_dir) + + +def test_metadata(tmp_dir): + action_space = spaces.Box(-1, 1) + observation_space = spaces.Box(-1, 1) + storage = MinariStorage.new( + data_path=tmp_dir, + observation_space=observation_space, + action_space=action_space + ) + assert storage.data_path == tmp_dir + + extra_metadata = { + "float": 3.2, + "string": "test-value", + "int": 2, + "bool": True + } + storage.update_metadata(extra_metadata) + + storage_metadata = storage.metadata + assert storage_metadata.keys() == { + 'action_space', + 'bool', + 'float', + 'int', + 'observation_space', + 'string', + 'total_episodes', + 'total_steps' + } + + for key, value in extra_metadata.items(): + assert storage_metadata[key] == value + + storage2 = MinariStorage(tmp_dir) + assert storage_metadata == storage2.metadata + + +def test_add_episodes(tmp_dir): + action_space = spaces.Box(-1, 1, shape=(10,)) + observation_space = spaces.Text(max_length=5) + n_episodes = 10 + steps_per_episode = 25 + episodes = [ + _generate_episode_dict(observation_space, action_space, length=steps_per_episode) + for _ in range(n_episodes) + ] + storage = MinariStorage.new( + data_path=tmp_dir, + observation_space=observation_space, + action_space=action_space + ) + storage.update_episodes(episodes) + del storage + storage = MinariStorage(tmp_dir) + + assert storage.total_episodes == n_episodes + assert storage.total_steps == n_episodes * steps_per_episode + + for i, ep in enumerate(episodes): + storage_ep = storage.get_episodes([i])[0] + + assert np.all(ep["observations"] == storage_ep["observations"]) + assert np.all(ep["actions"] == storage_ep["actions"]) + assert np.all(ep["rewards"] == storage_ep["rewards"]) + assert np.all(ep["terminations"] == storage_ep["terminations"]) + assert np.all(ep["truncations"] == storage_ep["truncations"]) + + +def test_append_episode_chunks(tmp_dir): + action_space = spaces.Discrete(10) + observation_space = spaces.Text(max_length=5) + lens = [10, 7, 15] + chunk1 = _generate_episode_dict(observation_space, action_space, length=lens[0]) + chunk2 = _generate_episode_dict(observation_space, action_space, length=lens[1]) + chunk3 = _generate_episode_dict(observation_space, action_space, length=lens[2]) + chunk1["terminations"][-1] = False + chunk1["truncations"][-1] = False + chunk2["terminations"][-1] = False + chunk2["truncations"][-1] = False + chunk2["observations"] = chunk2["observations"][:-1] + chunk3["observations"] = chunk3["observations"][:-1] + + storage = MinariStorage.new(tmp_dir, observation_space, action_space) + storage.update_episodes([chunk1]) + assert storage.total_episodes == 1 + assert storage.total_steps == lens[0] + + chunk2["id"] = 0 + chunk3["id"] = 0 + storage.update_episodes([chunk2, chunk3]) + assert storage.total_episodes == 1 + assert storage.total_steps == sum(lens) + + +def test_apply(tmp_dir): + action_space = spaces.Box(-1, 1, shape=(10,)) + observation_space = spaces.Text(max_length=5) + n_episodes = 10 + 
episodes = [ + _generate_episode_dict(observation_space, action_space) + for _ in range(n_episodes) + ] + storage = MinariStorage.new( + data_path=tmp_dir, + observation_space=observation_space, + action_space=action_space + ) + storage.update_episodes(episodes) + + def f(ep): + return ep["actions"].sum() + + episode_indices = [1, 3, 5] + outs = storage.apply(f, episode_indices=episode_indices) + + assert len(outs) == len(episode_indices) + for i, result in zip(episode_indices, outs): + assert np.array(episodes[i]["actions"]).sum() == result + + +def test_episode_metadata(tmp_dir): + action_space = spaces.Box(-1, 1, shape=(10,)) + observation_space = spaces.Text(max_length=5) + n_episodes = 10 + episodes = [ + _generate_episode_dict(observation_space, action_space) + for _ in range(n_episodes) + ] + storage = MinariStorage.new( + data_path=tmp_dir, + observation_space=observation_space, + action_space=action_space + ) + storage.update_episodes(episodes) + + ep_metadatas = [ + {"foo1-1": True, "foo1-2": 7}, + {"foo2-1": 3.14}, + {"foo3-1": "foo", "foo3-2": 42, "foo3-3": "test"}, + ] + + with pytest.raises( + ValueError, + match="The number of metadatas doesn't match the number of episodes to update." + ): + storage.update_episode_metadata(ep_metadatas) + + ep_indices = [1, 4, 5] + storage.update_episode_metadata(ep_metadatas, episode_indices=ep_indices) From ab36b074992e179aabe60a324f56040b0aafa6cc Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 9 Sep 2023 19:34:15 -0400 Subject: [PATCH 08/19] remove PettingZoo leftovers --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aba7b3c9..74a74ed6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: hooks: - id: flake8 args: - - '--per-file-ignores=*/__init__.py:F401 test/all_parameter_combs_test.py:F405 pettingzoo/classic/go/go.py:W605' + - '--per-file-ignores=*/__init__.py:F401' - --extend-ignore=E203 - --max-complexity=205 - --max-line-length=300 @@ -48,7 +48,7 @@ repos: - --count # TODO: Remove ignoring rules D101, D102, D103, D105 - --add-ignore=D100,D107,D101,D102,D103,D105 - exclude: "__init__.py$|^pettingzoo.test|^docs" + exclude: "__init__.py$|^docs" additional_dependencies: ["toml"] - repo: local hooks: From ef28c2fc84f27fd6b42e22b908989f3f7aba5df3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 9 Sep 2023 19:34:29 -0400 Subject: [PATCH 09/19] remove h5py in pydoc --- minari/dataset/minari_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 05f876fa..21324bac 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -202,7 +202,7 @@ def filter_episodes( ``` Args: - condition (Callable[[EpisodeData], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met. + condition (Callable[[EpisodeData], bool]): function that gets in input an EpisodeData object and returns True if certain condition is met. 
""" def dict_to_episode_data_condition(episode: dict) -> bool: From 856001844f6b3c16d9d7c782c303470f1b0e6fd5 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 12 Sep 2023 13:36:06 -0400 Subject: [PATCH 10/19] refactor minari_storage --- minari/dataset/minari_dataset.py | 31 +++---- minari/dataset/minari_storage.py | 124 ++++++++++++--------------- tests/dataset/test_minari_dataset.py | 25 ------ tests/dataset/test_minari_storage.py | 11 +-- 4 files changed, 69 insertions(+), 122 deletions(-) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 21324bac..72093722 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -117,29 +117,17 @@ def __init__( self._combined_datasets = metadata.get("combined_datasets", []) - # We will default to using the reconstructed observation and action spaces from the dataset - # and fall back to the env spec env if the action and observation spaces are not both present - # in the dataset. + # By default, we use the observation and action spaces from the dataset and + # we fall back to the env if one of them is not in the dataset. observation_space = metadata.get("observation_space") action_space = metadata.get("action_space") if observation_space is None or action_space is None: - # Checking if the base library of the environment is present in the environment - entry_point = json.loads(env_spec)["entry_point"] - lib_full_path = entry_point.split(":")[0] - base_lib = lib_full_path.split(".")[0] - env_name = self._env_spec.id - - try: - env = gym.make(self._env_spec) - if observation_space is None: - observation_space = env.observation_space - if action_space is None: - action_space = env.action_space - env.close() - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Install {base_lib} for loading {env_name} data" - ) from e + env = self.recover_environment() + if observation_space is None: + observation_space = env.observation_space + if action_space is None: + action_space = env.action_space + env.close() assert isinstance(observation_space, gym.spaces.Space) assert isinstance(action_space, gym.spaces.Space) self._observation_space = observation_space @@ -212,7 +200,8 @@ def dict_to_episode_data_condition(episode: dict) -> bool: dict_to_episode_data_condition, episode_indices=self._episode_indices ) assert self._episode_indices is not None - return MinariDataset(self._data, episode_indices=self._episode_indices[mask]) + filtered_indices = self._episode_indices[list(mask)] + return MinariDataset(self._data, episode_indices=filtered_indices) def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]: """Sample n number of episodes from the dataset. diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 8927d007..38e86640 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -99,21 +99,19 @@ def update_metadata(self, metadata: Dict): with h5py.File(self._file_path, "a") as file: file.attrs.update(metadata) - def update_episode_metadata(self, metadatas: List[Dict], episode_indices: Optional[Iterable] = None): + def update_episode_metadata(self, metadatas: Iterable[Dict], episode_indices: Optional[Iterable] = None): """Update the metadata of episodes. Args: - metadatas (List[Dict]): list of metadatas, one for each episode. - episode_indices (Iterable, optional): list of episode indices to update. + metadatas (Iterable[Dict]): metadatas, one for each episode. 
+ episode_indices (Iterable, optional): episode indices to update. If not specified, all the episodes are considered. - Raises: - ValueError: if the lengths of metadatas and episodes to update don't match. + Warning: + In case metadatas and episode_indices have different lengths, the longest is truncated silently. """ if episode_indices is None: episode_indices = range(self.total_episodes) - if len(metadatas) != len(list(episode_indices)): - raise ValueError("The number of metadatas doesn't match the number of episodes to update.") with h5py.File(self._file_path, "a") as file: for metadata, episode_id in zip(metadatas, episode_indices): @@ -124,7 +122,7 @@ def apply( self, function: Callable[[dict], Any], episode_indices: Optional[Iterable] = None, - ) -> List[Any]: + ) -> Iterable[Any]: """Apply a function to a slice of the data. Args: @@ -132,39 +130,21 @@ def apply( episode_indices (Optional[Iterable]): episodes id to consider Returns: - outs (list): list of outputs returned by the function applied to episodes + outs (Iterable): outputs returned by the function applied to episodes """ if episode_indices is None: episode_indices = range(self.total_episodes) - out = [] - with h5py.File(self._file_path, "r") as file: - for ep_idx in episode_indices: - ep_group = file[f"episode_{ep_idx}"] - assert isinstance(ep_group, h5py.Group) - ep_dict = { - "id": ep_group.attrs.get("id"), - "total_timesteps": ep_group.attrs.get("total_steps"), - "seed": ep_group.attrs.get("seed"), - # TODO: self.metadata can be slow for decode space? Cache spaces? Cache metadata? - "observations": self._decode_space( - ep_group["observations"], self.metadata["observation_space"] - ), - "actions": self._decode_space( - ep_group["actions"], self.metadata["action_space"] - ), - "rewards": ep_group["rewards"][()], - "terminations": ep_group["terminations"][()], - "truncations": ep_group["truncations"][()], - } - out.append(function(ep_dict)) - - return out + + ep_dicts = self.get_episodes(episode_indices) + return map(function, ep_dicts) def _decode_space( self, - hdf_ref: Union[h5py.Group, h5py.Dataset], + hdf_ref: Union[h5py.Group, h5py.Dataset, h5py.Datatype], space: gym.spaces.Space, ) -> Union[Dict, Tuple, List, np.ndarray]: + assert not isinstance(hdf_ref, h5py.Datatype) + if isinstance(space, gym.spaces.Tuple): assert isinstance(hdf_ref, h5py.Group) result = [] @@ -176,7 +156,7 @@ def _decode_space( elif isinstance(space, gym.spaces.Dict): assert isinstance(hdf_ref, h5py.Group) result = {} - for key in hdf_ref: + for key in hdf_ref.keys(): result[key] = self._decode_space(hdf_ref[key], space.spaces[key]) return result elif isinstance(space, gym.spaces.Text): @@ -200,22 +180,25 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] - out.append( - { - "id": ep_group.attrs.get("id"), - "total_timesteps": ep_group.attrs.get("total_steps"), - "seed": ep_group.attrs.get("seed"), - "observations": self._decode_space( - ep_group["observations"], self.metadata["observation_space"] - ), - "actions": self._decode_space( - ep_group["actions"], self.metadata["action_space"] - ), - "rewards": ep_group["rewards"][()], - "terminations": ep_group["terminations"][()], - "truncations": ep_group["truncations"][()], - } - ) + assert isinstance(ep_group, h5py.Group) + + ep_dict = { + "id": ep_group.attrs.get("id"), + "total_timesteps": ep_group.attrs.get("total_steps"), + "seed": ep_group.attrs.get("seed"), 
+ "observations": self._decode_space( + ep_group["observations"], self.metadata["observation_space"] # TODO: metadata can be slow + ), + "actions": self._decode_space( + ep_group["actions"], self.metadata["action_space"] + ), + } + for key in {"rewards", "terminations", "truncations"}: + group_value = ep_group[key] + assert isinstance(group_value, h5py.Dataset) + ep_dict[key] = group_value[:] + + out.append(ep_dict) return out @@ -242,10 +225,11 @@ def update_episodes(self, episodes: Iterable[dict]): episode_group.attrs["total_steps"] = total_steps additional_steps += total_steps - # TODO: make it append _add_episode_to_group(eps_buff, episode_group) - total_steps = file.attrs["total_steps"] + additional_steps + current_steps = file.attrs["total_steps"] + assert type(current_steps) == np.int64 + total_steps = current_steps + additional_steps total_episodes = len(file.keys()) file.attrs.modify("total_episodes", total_episodes) @@ -268,9 +252,9 @@ def total_episodes(self) -> np.int64: def total_steps(self) -> np.int64: """Total steps in the dataset.""" with h5py.File(self._file_path, "r") as file: - total_episodes = file.attrs["total_steps"] - assert type(total_episodes) == np.int64 - return total_episodes + total_steps = file.attrs["total_steps"] + assert type(total_steps) == np.int64 + return total_steps def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: @@ -286,27 +270,31 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): if isinstance(data, dict): episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(data, episode_group_to_clear) - elif all([isinstance(entry, tuple) for entry in data]): - # we have a list of tuples, so we need to act appropriately + elif all([isinstance(entry, tuple) for entry in data]): # list of tuples dict_data = { f"_index_{str(i)}": [entry[i] for entry in data] for i, _ in enumerate(data[0]) } episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) - elif all([isinstance(entry, OrderedDict) for entry in data]): - # we have a list of OrderedDicts, so we need to act appropriately + elif all([isinstance(entry, OrderedDict) for entry in data]): # list of OrderedDict dict_data = { - key: [entry[key] for entry in data] for key, value in data[0].items() + key: [entry[key] for entry in data] for key in data[0].keys() } episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) - else: # leaf data - if isinstance(episode_group, h5py.Dataset): - pass #TODO - elif all(map(lambda elem: isinstance(elem, str), data)): + + # leaf data + elif key in episode_group: + dataset = episode_group[key] + assert isinstance(dataset, h5py.Dataset) + dataset.resize((dataset.shape[0] + len(data), *dataset.shape[1:])) + dataset[-len(data):] = data + else: + dtype = None + if all(map(lambda elem: isinstance(elem, str), data)): dtype = h5py.string_dtype(encoding="utf-8") - episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True) - else: - assert np.all(np.logical_not(np.isnan(data))) - episode_group.create_dataset(key, data=data, chunks=True) \ No newline at end of file + dshape = () + if hasattr(data[0], "shape"): + dshape = data[0].shape + episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape)) \ No newline at end of file diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 660697f2..50eef772 
100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -457,28 +457,3 @@ def test_update_dataset_from_buffer(dataset_id, env_id): env.close() check_load_and_delete_dataset(dataset_id) - - -def test_missing_env_module(): - data_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets", "dummy-test-v0") - storage = MinariStorage.new( - data_path, - observation_space=gym.spaces.Box(-1, 1), - action_space=gym.spaces.Box(-1, 1), - ) - storage.update_metadata({ - "flatten_observation": False, - "flatten_action": False, - "env_spec": r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""", - "total_episodes": 100, - "total_steps": 1000, - "dataset_id": "dummy-test-v0", - "minari_version": f"=={__version__}" - }) - - with pytest.raises( - ModuleNotFoundError, match="Install dummymodule for loading DummyEnv-v0 data" - ): - MinariDataset(storage.data_path) - - os.remove(data_path) \ No newline at end of file diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index b49068e6..311d0aef 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -149,10 +149,11 @@ def f(ep): episode_indices = [1, 3, 5] outs = storage.apply(f, episode_indices=episode_indices) - - assert len(outs) == len(episode_indices) + assert len(episode_indices) == len(list(outs)) for i, result in zip(episode_indices, outs): assert np.array(episodes[i]["actions"]).sum() == result + + def test_episode_metadata(tmp_dir): @@ -176,11 +177,5 @@ def test_episode_metadata(tmp_dir): {"foo3-1": "foo", "foo3-2": 42, "foo3-3": "test"}, ] - with pytest.raises( - ValueError, - match="The number of metadatas doesn't match the number of episodes to update." 
- ): - storage.update_episode_metadata(ep_metadatas) - ep_indices = [1, 4, 5] storage.update_episode_metadata(ep_metadatas, episode_indices=ep_indices) From 01fe219fa8b4036ea7880f76fa786f5af8cf0ede Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 13 Sep 2023 18:11:48 -0400 Subject: [PATCH 11/19] factor combine_datasets --- minari/data_collector/data_collector.py | 12 ++-- minari/dataset/minari_dataset.py | 5 ++ minari/dataset/minari_storage.py | 51 +++++++++++++--- minari/utils.py | 80 +++++-------------------- 4 files changed, 69 insertions(+), 79 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 3315bfbb..11c97678 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -88,13 +88,13 @@ def __init__( self._step_data_callback = step_data_callback() self._episode_metadata_callback = episode_metadata_callback() - if observation_space is None: - observation_space = self.env.observation_space self.dataset_observation_space = observation_space + if self.dataset_observation_space is None: + self.dataset_observation_space = self.env.observation_space - if action_space is None: - action_space = self.env.action_space self.dataset_action_space = action_space + if self.dataset_action_space is None: + self.dataset_action_space = self.env.action_space self._record_infos = record_infos self.max_buffer_steps = max_buffer_steps @@ -119,8 +119,8 @@ def __init__( assert self.env.spec is not None, "Env Spec is None" self._storage = MinariStorage.new( self._tmp_dir.name, - observation_space=self.dataset_observation_space, - action_space=self.dataset_action_space, + observation_space=observation_space, + action_space=action_space, env_spec=self.env.spec ) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 72093722..29745986 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -363,3 +363,8 @@ def id(self) -> str: def minari_version(self) -> str: """Version of Minari the dataset is compatible with.""" return self._minari_version + + @property + def storage(self) -> MinariStorage: + """MinariStorage managing access to disk.""" + return self._data diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 38e86640..70e552d8 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -41,17 +41,17 @@ def __init__(self, data_path: PathLike): def new( cls, data_path: PathLike, - observation_space: gym.Space, - action_space: gym.Space, + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, env_spec: Optional[EnvSpec] = None ) -> MinariStorage: """Class method to create a new data storage. Args: data_path (str or Path): directory where the data will be stored. - observation_space (gymnasium.Space): Gymnasium observation space of the dataset. - action_space (gymnasium.Space): Gymnasium action space of the dataset. - env_spec (EnvSpec): Gymnasium EnvSpec of the environment that generates the dataset. + observation_space (gymnasium.Space, optional): Gymnasium observation space of the dataset. + action_space (gymnasium.Space, optional): Gymnasium action space of the dataset. + env_spec (EnvSpec, optional): Gymnasium EnvSpec of the environment that generates the dataset. Returns: A new MinariStorage object. 
@@ -61,12 +61,14 @@ def new( data_path.joinpath("main_data.hdf5").touch(exist_ok=False) obj = cls(data_path) - metadata = { - "observation_space": serialize_space(observation_space), - "action_space": serialize_space(action_space), + metadata: Dict[str, Any] = { "total_episodes": 0, "total_steps": 0 } + if observation_space is not None: + metadata["observation_space"] = serialize_space(observation_space) + if action_space is not None: + metadata["action_space"] = serialize_space(action_space) if env_spec is not None: metadata["env_spec"] = env_spec.to_json() @@ -235,6 +237,39 @@ def update_episodes(self, episodes: Iterable[dict]): file.attrs.modify("total_episodes", total_episodes) file.attrs.modify("total_steps", total_steps) + def update_from_storage(self, storage: MinariStorage, copy: bool = False): + """Update the dataset using another MinariStorage. + + Args: + storage (MinariStorage): the other MinariStorage from which the data will be taken + copy (bool): whether to copy the data or create a link. Default value is false. + """ + with h5py.File(self._file_path, "a", track_order=True) as file: + last_episode_id = file.attrs["total_episodes"] + assert type(last_episode_id) == np.int64 + storage_total_episodes = storage.total_episodes + + if copy: + for id in range(storage.total_episodes): + episode = storage.get_episodes([id]) + episode[0].pop("id") + self.update_episodes(episode) + else: + for id in range(storage_total_episodes): + file[f"episode_{last_episode_id + id}"] = h5py.ExternalLink(storage._file_path, f"/episode_{id}") + file[f"episode_{last_episode_id + id}"].attrs.modify( # TODO: check it doesn't modify original dataset + "id", last_episode_id + id + ) + + file.attrs.modify("total_episodes", last_episode_id + storage_total_episodes) + total_steps = file.attrs["total_steps"] + assert type(total_steps) == np.int64 + file.attrs.modify("total_steps", total_steps + storage.total_steps) + + storage_metadata = storage.metadata + file.attrs.modify("author", f'{file.attrs["author"]}; {storage_metadata["author"]}') + file.attrs.modify("author_email", f'{file.attrs["author_email"]}; {storage_metadata["author_email"]}') + @property def data_path(self) -> PathLike: """Full path to the `main_data.hdf5` file of the dataset.""" diff --git a/minari/utils.py b/minari/utils.py index 6c465273..900bba82 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -202,19 +202,18 @@ def __call__(self, observation: ObsType) -> ActType: return self.action_space.sample() -# TODO: factor h5py out def combine_datasets( datasets_to_combine: List[MinariDataset], new_dataset_id: str, copy: bool = False ): """Combine a group of MinariDataset in to a single dataset with its own name id. - A new HDF5 metadata attribute will be added to the new dataset called `combined_datasets`. This will - contain a list of strings with the dataset names that were combined to form this new Minari dataset. + The new dataset will contain a metadata attribute `combined_datasets` containing a list + with the dataset names that were combined to form this new Minari dataset. 
Args: datasets_to_combine (list[MinariDataset]): list of datasets to be combined new_dataset_id (str): name id for the newly created dataset - copy (bool): whether to copy the data to a new dataset or to create external link (see h5py.ExternalLink) + copy (bool): whether to copy the data to a new dataset or to create a link Returns: combined_dataset (MinariDataset): the resulting MinariDataset @@ -230,71 +229,22 @@ def combine_datasets( datasets_minari_version_specifiers ) - new_dataset_path = get_dataset_path(new_dataset_id) + new_dataset_path = get_dataset_path(new_dataset_id).joinpath("data") + new_storage = MinariStorage.new(new_dataset_path) - # Check if dataset already exists - if not os.path.exists(new_dataset_path): - new_dataset_path = os.path.join(new_dataset_path, "data") - os.makedirs(new_dataset_path) - new_file_path = os.path.join(new_dataset_path, "main_data.hdf5") - else: - raise ValueError( - f"A Minari dataset with ID {new_dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." - ) - - with h5py.File(new_file_path, "a", track_order=True) as combined_data_file: - combined_data_file.attrs["total_episodes"] = 0 - combined_data_file.attrs["total_steps"] = 0 - combined_data_file.attrs["dataset_id"] = new_dataset_id - - combined_data_file.attrs["combined_datasets"] = [ + new_storage.update_metadata({ + "dataset_id": new_dataset_id, + "combined_datasets": [ dataset.spec.dataset_id for dataset in datasets_to_combine - ] - - for dataset in datasets_to_combine: - last_episode_id = combined_data_file.attrs["total_episodes"] - file_path = f"{dataset.spec.data_path}/main_data.hdf5" - if copy: - with h5py.File(file_path, "r") as dataset_file: - for id in range(dataset.total_episodes): - dataset_file.copy( - dataset_file[f"episode_{id}"], - combined_data_file, - name=f"episode_{last_episode_id + id}", - ) - combined_data_file[ - f"episode_{last_episode_id + id}" - ].attrs.modify("id", last_episode_id + id) - else: - for id in range(dataset.total_episodes): - combined_data_file[ - f"episode_{last_episode_id + id}" - ] = h5py.ExternalLink(file_path, f"/episode_{id}") - combined_data_file[f"episode_{last_episode_id + id}"].attrs.modify( - "id", last_episode_id + id - ) - - # Update metadata of minari dataset - combined_data_file.attrs.modify( - "total_episodes", last_episode_id + dataset.total_episodes - ) - combined_data_file.attrs.modify( - "total_steps", - combined_data_file.attrs["total_steps"] + dataset.spec.total_steps, - ) + ], + "env_spec": combined_dataset_env_spec.to_json(), + "minari_version": str(minari_version_specifier) + }) - # TODO: list of authors, and emails - with h5py.File(file_path, "r") as dataset_file: - combined_data_file.attrs.modify("author", dataset_file.attrs["author"]) - combined_data_file.attrs.modify( - "author_email", dataset_file.attrs["author_email"] - ) - - assert combined_dataset_env_spec is not None - combined_data_file.attrs["env_spec"] = combined_dataset_env_spec.to_json() - combined_data_file.attrs["minari_version"] = str(minari_version_specifier) + for dataset in datasets_to_combine: + new_storage.update_from_storage(dataset.storage, copy=copy) - return MinariDataset(new_dataset_path) + return MinariDataset(new_storage) def split_dataset( From 67c0aee3629d262302f90687f4b8a32f316ca02c Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 14 Sep 2023 13:23:00 -0400 Subject: [PATCH 12/19] cache spaces and make them optional in metadata --- minari/data_collector/data_collector.py | 34 +++++----- 
minari/dataset/minari_storage.py | 87 ++++++++++++++++++------- minari/utils.py | 10 ++- tests/dataset/test_minari_dataset.py | 21 ++++++ tests/utils/test_dataset_combine.py | 4 +- 5 files changed, 109 insertions(+), 47 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 11c97678..3eee6455 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -88,23 +88,6 @@ def __init__( self._step_data_callback = step_data_callback() self._episode_metadata_callback = episode_metadata_callback() - self.dataset_observation_space = observation_space - if self.dataset_observation_space is None: - self.dataset_observation_space = self.env.observation_space - - self.dataset_action_space = action_space - if self.dataset_action_space is None: - self.dataset_action_space = self.env.action_space - - self._record_infos = record_infos - self.max_buffer_steps = max_buffer_steps - - # Initialzie empty buffer - self._buffer: List[EpisodeBuffer] = [] - - self._step_id = -1 - self._episode_id = -1 - # get path to minari datasets directory self.datasets_path = os.environ.get("MINARI_DATASETS_PATH") if self.datasets_path is None: @@ -124,6 +107,23 @@ def __init__( env_spec=self.env.spec ) + if observation_space is None: + observation_space = self.env.observation_space + self.dataset_observation_space = observation_space + + if action_space is None: + action_space = self.env.action_space + self.dataset_action_space = action_space + + self._record_infos = record_infos + self.max_buffer_steps = max_buffer_steps + + # Initialzie empty buffer + self._buffer: List[EpisodeBuffer] = [] + + self._step_id = -1 + self._episode_id = -1 + def _add_to_episode_buffer( self, episode_buffer: EpisodeBuffer, diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 70e552d8..a6e9cbe7 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -37,6 +37,9 @@ def __init__(self, data_path: PathLike): raise ValueError(f"No data found in data path {data_path}") self._file_path = file_path + self._observation_space = None + self._action_space = None + @classmethod def new( cls, @@ -55,7 +58,12 @@ def new( Returns: A new MinariStorage object. + + Raises: + ValueError: if you don't specify the env_spec, you need to specify both observation_space and action_space. 
""" + if env_spec is None and (observation_space is None or action_space is None): + raise ValueError("Since env_spec is not specified, you need to specify both action space and observation space!") data_path = pathlib.Path(data_path) data_path.mkdir(exist_ok=True) data_path.joinpath("main_data.hdf5").touch(exist_ok=False) @@ -67,12 +75,15 @@ def new( } if observation_space is not None: metadata["observation_space"] = serialize_space(observation_space) + obj._observation_space = observation_space if action_space is not None: metadata["action_space"] = serialize_space(action_space) + obj._action_space = action_space if env_spec is not None: metadata["env_spec"] = env_spec.to_json() - obj.update_metadata(metadata) + with h5py.File(obj._file_path, "a") as file: + file.attrs.update(metadata) return obj @property @@ -81,16 +92,10 @@ def metadata(self) -> Dict: metadata = {} with h5py.File(self._file_path, "r") as file: metadata.update(file.attrs) - if "observation_space" in metadata.keys(): - space_serialization = metadata["observation_space"] - assert isinstance(space_serialization, str) - metadata["observation_space"] = deserialize_space(space_serialization) - if "action_space" in metadata.keys(): - space_serialization = metadata["action_space"] - assert isinstance(space_serialization, str) - metadata["action_space"] = deserialize_space(space_serialization) - - return metadata + + metadata["observation_space"] = self.observation_space + metadata["action_space"] = self.action_space + return metadata def update_metadata(self, metadata: Dict): """Update the metadata adding/modifying some keys. @@ -98,6 +103,9 @@ def update_metadata(self, metadata: Dict): Args: metadata (dict): dictionary of keys-values to add to the metadata. """ + forbidden_keys = {"observation_space", "action_space", "env_spec"}.intersection(metadata.keys()) + if forbidden_keys: + raise ValueError(f"You are not allowed to update values for {', '.join(forbidden_keys)}") with h5py.File(self._file_path, "a") as file: file.attrs.update(metadata) @@ -188,12 +196,8 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: "id": ep_group.attrs.get("id"), "total_timesteps": ep_group.attrs.get("total_steps"), "seed": ep_group.attrs.get("seed"), - "observations": self._decode_space( - ep_group["observations"], self.metadata["observation_space"] # TODO: metadata can be slow - ), - "actions": self._decode_space( - ep_group["actions"], self.metadata["action_space"] - ), + "observations": self._decode_space(ep_group["observations"], self.observation_space), + "actions": self._decode_space(ep_group["actions"], self.action_space), } for key in {"rewards", "terminations", "truncations"}: group_value = ep_group[key] @@ -253,6 +257,7 @@ def update_from_storage(self, storage: MinariStorage, copy: bool = False): for id in range(storage.total_episodes): episode = storage.get_episodes([id]) episode[0].pop("id") + episode[0].pop("total_timesteps") self.update_episodes(episode) else: for id in range(storage_total_episodes): @@ -260,15 +265,16 @@ def update_from_storage(self, storage: MinariStorage, copy: bool = False): file[f"episode_{last_episode_id + id}"].attrs.modify( # TODO: check it doesn't modify original dataset "id", last_episode_id + id ) - - file.attrs.modify("total_episodes", last_episode_id + storage_total_episodes) - total_steps = file.attrs["total_steps"] - assert type(total_steps) == np.int64 - file.attrs.modify("total_steps", total_steps + storage.total_steps) + file.attrs.modify("total_episodes", last_episode_id + 
storage_total_episodes) + total_steps = file.attrs["total_steps"] + assert type(total_steps) == np.int64 + file.attrs.modify("total_steps", total_steps + storage.total_steps) storage_metadata = storage.metadata - file.attrs.modify("author", f'{file.attrs["author"]}; {storage_metadata["author"]}') - file.attrs.modify("author_email", f'{file.attrs["author_email"]}; {storage_metadata["author_email"]}') + authors = [file.attrs.get("author"), storage_metadata.get("author")] + file.attrs.modify("author", '; '.join([aut for aut in authors if aut is not None])) + emails = [file.attrs.get("author_email"), storage_metadata.get("author_email")] + file.attrs.modify("author_email", '; '.join([e for e in emails if e is not None])) @property def data_path(self) -> PathLike: @@ -290,6 +296,39 @@ def total_steps(self) -> np.int64: total_steps = file.attrs["total_steps"] assert type(total_steps) == np.int64 return total_steps + + @property + def observation_space(self) -> gym.Space: + """Observation Space of the dataset.""" + if self._observation_space is None: + with h5py.File(self._file_path, "r") as file: + if "observation_space" in file.attrs.keys(): + serialized_space = file.attrs["observation_space"] + assert isinstance(serialized_space, str) + self._observation_space = deserialize_space(serialized_space) + else: + env_spec_str = file.attrs.get("env_spec") + assert isinstance(env_spec_str, str) + env_spec = EnvSpec.from_json(env_spec_str) + self._observation_space = gym.make(env_spec).observation_space + return self._observation_space + + @property + def action_space(self) -> gym.Space: + """Action space of the dataset.""" + if self._action_space is None: + with h5py.File(self._file_path, "r") as file: + if "action_space" in file.attrs.keys(): + serialized_space = file.attrs["action_space"] + assert isinstance(serialized_space, str) + self._action_space = deserialize_space(serialized_space) + else: + env_spec_str = file.attrs.get("env_spec") + assert isinstance(env_spec_str, str) + env_spec = EnvSpec.from_json(env_spec_str) + self._action_space = gym.make(env_spec).action_space + + return self._action_space def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: diff --git a/minari/utils.py b/minari/utils.py index 900bba82..a88d965d 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -229,15 +229,19 @@ def combine_datasets( datasets_minari_version_specifiers ) - new_dataset_path = get_dataset_path(new_dataset_id).joinpath("data") - new_storage = MinariStorage.new(new_dataset_path) + + new_dataset_path = get_dataset_path(new_dataset_id) + new_dataset_path.mkdir() + new_storage = MinariStorage.new( + new_dataset_path.joinpath("data"), + env_spec=combined_dataset_env_spec + ) new_storage.update_metadata({ "dataset_id": new_dataset_id, "combined_datasets": [ dataset.spec.dataset_id for dataset in datasets_to_combine ], - "env_spec": combined_dataset_env_spec.to_json(), "minari_version": str(minari_version_specifier) }) diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 50eef772..fb5bf2a4 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -1,9 +1,11 @@ import copy import os import re +import shutil from typing import Any import gymnasium as gym +from gymnasium.envs.registration import EnvSpec import numpy as np import pytest @@ -457,3 +459,22 @@ def test_update_dataset_from_buffer(dataset_id, env_id): env.close() check_load_and_delete_dataset(dataset_id) + + +def test_missing_env_module(): + 
data_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets", "dummy-test-v0") + class FakeEnvSpec(EnvSpec): + def to_json(self) -> str: + return r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""" + + storage = MinariStorage.new( + data_path, + env_spec=FakeEnvSpec("DummyEnv-v0"), + ) + + with pytest.raises( + ModuleNotFoundError, match="No module named 'dummymodule'" + ): + MinariDataset(storage.data_path) + + shutil.rmtree(data_path) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index b90782f4..eaa67025 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -140,9 +140,7 @@ def test_combine_datasets(): assert isinstance(combined_dataset, MinariDataset) assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids assert combined_dataset.spec.total_episodes == num_datasets * num_episodes - assert combined_dataset.spec.total_steps == sum( - d.spec.total_steps for d in test_datasets - ) + assert combined_dataset.spec.total_steps == sum(d.spec.total_steps for d in test_datasets) _check_env_recovery(gym.make("CartPole-v1"), combined_dataset) # deleting test datasets From 5a85fafb85ea9a6ea8060937b87ddad4e757456a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 19 Sep 2023 13:33:21 -0400 Subject: [PATCH 13/19] test fixes --- minari/data_collector/data_collector.py | 45 +++++- minari/dataset/minari_dataset.py | 144 ++++++------------ minari/dataset/minari_storage.py | 44 +++--- minari/utils.py | 6 +- tests/common.py | 1 - .../callbacks/test_step_data_callback.py | 2 +- tests/dataset/test_dataset_download.py | 2 +- tests/dataset/test_minari_dataset.py | 30 ++-- tests/utils/test_dataset_combine.py | 31 +--- tests/utils/test_dataset_creation.py | 2 +- 10 files changed, 133 insertions(+), 174 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 3eee6455..be53dee1 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -7,6 +7,7 @@ import gymnasium as gym from gymnasium.core import ActType, ObsType +import numpy as np from minari.data_collector.callbacks import ( STEP_DATA_KEYS, @@ -14,6 +15,7 @@ StepData, StepDataCallback, ) +from minari.dataset.minari_dataset import MinariDataset from minari.dataset.minari_storage import MinariStorage @@ -99,7 +101,6 @@ def __init__( os.makedirs(self.datasets_path) self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - assert self.env.spec is not None, "Env Spec is None" self._storage = MinariStorage.new( self._tmp_dir.name, observation_space=observation_space, @@ -242,13 +243,38 @@ def _validate_buffer(self): self._episode_id -= 1 elif not self._buffer[-1]["terminations"][-1]: self._buffer[-1]["truncations"][-1] = True + + def add_to_dataset(self, dataset: MinariDataset): + """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). 
+
+        Args:
+            dataset (MinariDataset): dataset to which the data will be added
+        """
+        self._validate_buffer()
+        self._storage.update_episodes(self._buffer)
+        self._buffer.clear()
+
+        first_id = dataset.storage.total_episodes
+        dataset.storage.update_from_storage(self._storage)
+        if dataset.episode_indices is not None:
+            new_ids = first_id + np.arange(self._storage.total_episodes)
+            dataset.episode_indices = np.append(dataset.episode_indices, new_ids)
+
+        self._episode_id = -1
+        self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path)
+        self._storage = MinariStorage.new(
+            self._tmp_dir.name,
+            observation_space=self._storage.observation_space,
+            action_space=self._storage.action_space,
+            env_spec=self.env.spec
+        )
 
     def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}):
-        """Save all in-memory buffer data and move temporary HDF5 file to a permanent location in disk.
+        """Save all in-memory buffer data and move temporary files to a permanent location on disk.
 
         Args:
             path (str): path to store the dataset, e.g.: '/home/foo/datasets/data'
-            dataset_metadata (Dict, optional): additional metadata to add to HDF5 dataset file as attributes. Defaults to {}.
+            dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}.
         """
         self._validate_buffer()
         self._storage.update_episodes(self._buffer)
@@ -275,8 +301,14 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}):
                 os.path.join(path, file),
             )
 
-        # Reset episode count
-        self._episode_id = 0
+        self._episode_id = -1
+        self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path)
+        self._storage = MinariStorage.new(
+            self._tmp_dir.name,
+            observation_space=self._storage.observation_space,
+            action_space=self._storage.action_space,
+            env_spec=self.env.spec
+        )
 
     def close(self):
         """Close the DataCollector.
@@ -285,8 +317,5 @@ def close(self): """ super().close() - # Clear buffer self._buffer.clear() - - # Close tmp_dataset.hdf5 shutil.rmtree(self._tmp_dir.name) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 29745986..a518ff85 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -52,7 +52,7 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int]: class MinariDatasetSpec: env_spec: EnvSpec total_episodes: int - total_steps: int + total_steps: np.int64 dataset_id: str combined_datasets: List[str] observation_space: gym.Space @@ -93,6 +93,11 @@ def __init__( else: raise ValueError(f"Unrecognized type {type(data)} for data") + if episode_indices is None: + episode_indices = np.arange(self._data.total_episodes) + self._episode_indices: np.ndarray = episode_indices + self._total_steps = None + metadata = self._data.metadata env_spec = metadata["env_spec"] @@ -133,35 +138,6 @@ def __init__( self._observation_space = observation_space self._action_space = action_space - if episode_indices is None: - total_episodes = metadata["total_episodes"] - episode_indices = np.arange(total_episodes) - total_steps = metadata["total_steps"] - else: - total_steps = sum( - self._data.apply( - lambda episode: episode["total_timesteps"], - episode_indices=episode_indices, - ) - ) - - assert isinstance(episode_indices, np.ndarray) - self._episode_indices: np.ndarray = episode_indices - self._total_steps = total_steps - - assert self._episode_indices is not None - - self.spec = MinariDatasetSpec( - env_spec=self.env_spec, - total_episodes=self._episode_indices.size, - total_steps=total_steps, - dataset_id=self.id, - combined_datasets=self.combined_datasets, - observation_space=self.observation_space, - action_space=self.action_space, - data_path=str(self._data.data_path), - minari_version=str(self.minari_version), - ) self._generator = np.random.default_rng() def recover_environment(self) -> gym.Env: @@ -196,12 +172,12 @@ def filter_episodes( def dict_to_episode_data_condition(episode: dict) -> bool: return condition(EpisodeData(**episode)) - mask = self._data.apply( - dict_to_episode_data_condition, episode_indices=self._episode_indices + mask = self.storage.apply( + dict_to_episode_data_condition, episode_indices=self.episode_indices ) - assert self._episode_indices is not None - filtered_indices = self._episode_indices[list(mask)] - return MinariDataset(self._data, episode_indices=filtered_indices) + assert self.episode_indices is not None + filtered_indices = self.episode_indices[list(mask)] + return MinariDataset(self.storage, episode_indices=filtered_indices) def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]: """Sample n number of episodes from the dataset. 
@@ -212,7 +188,7 @@ def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]: indices = self._generator.choice( self.episode_indices, size=n_episodes, replace=False ) - episodes = self._data.get_episodes(indices) + episodes = self.storage.get_episodes(indices) return list(map(lambda data: EpisodeData(**data), episodes)) def iterate_episodes( @@ -231,48 +207,9 @@ def iterate_episodes( assert episode_indices is not None for episode_index in episode_indices: - data = self._data.get_episodes([episode_index])[0] + data = self.storage.get_episodes([episode_index])[0] yield EpisodeData(**data) - # def update_dataset_from_collector_env(self, collector_env: DataCollectorV0): - # """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). - - # This method can be used as a checkpoint when creating a dataset. - # A new HDF5 file will be created with the new dataset file in the same directory as `main_data.hdf5` called - # `additional_data_i.hdf5`. Both datasets are joined together by creating external links to each additional - # episode group: https://docs.h5py.org/en/stable/high/group.html#external-links - - # Args: - # collector_env (DataCollectorV0): Collector environment - # """ - # # check that collector env has the same characteristics as self._env_spec - # new_data_file_path = os.path.join( - # os.path.split(self.spec.data_path)[0], - # f"additional_data_{self._additional_data_id}.hdf5", - # ) - - # old_total_episodes = self._data.total_episodes - - # self._data.update_from_collector_env( - # collector_env, new_data_file_path, self._additional_data_id - # ) - - # new_total_episodes = self._data._total_episodes - - # self._additional_data_id += 1 - - # self._episode_indices = np.append( - # self._episode_indices, np.arange(old_total_episodes, new_total_episodes) - # ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) - - # self.spec.total_episodes = self._episode_indices.size - # self.spec.total_steps = sum( - # self._data.apply( - # lambda episode: episode["total_timesteps"], - # episode_indices=self._episode_indices, - # ) - # ) - def update_dataset_from_buffer(self, buffer: List[dict]): """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers. 
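# A short sketch of reading data back through the storage-backed accessors above;
# the dataset id and the episode indices are placeholders for a locally available
# dataset.
import minari

dataset = minari.load_dataset("cartpole-test-v0")
for ep in dataset.sample_episodes(5):  # random subset, sampled without replacement
    print(ep.id, ep.total_timesteps)
for ep in dataset.iterate_episodes([0, 2, 4]):  # deterministic order
    print(ep.id, ep.rewards.sum())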
@@ -288,29 +225,15 @@ def update_dataset_from_buffer(self, buffer: List[dict]): Args: buffer (list[dict]): list of episode dictionary buffers to add to dataset """ - old_total_episodes = self._data.total_episodes - self._data.update_episodes(buffer) - new_total_episodes = self._data.total_episodes - - self._episode_indices = np.append( - self._episode_indices, np.arange(old_total_episodes, new_total_episodes) - ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) - - self.spec.total_episodes = self._episode_indices.size - - # TODO: avoid this - self.spec.total_steps = sum( - self._data.apply( - lambda episode: episode["total_timesteps"], - episode_indices=self._episode_indices, - ) - ) + first_id = self.storage.total_episodes + self.storage.update_episodes(buffer) + self.episode_indices = np.append(self.episode_indices, first_id + np.arange(len(buffer))) def __iter__(self): return self.iterate_episodes() def __getitem__(self, idx: int) -> EpisodeData: - episodes_data = self._data.get_episodes([self.episode_indices[idx]]) + episodes_data = self.storage.get_episodes([self.episode_indices[idx]]) assert len(episodes_data) == 1 return EpisodeData(**episodes_data[0]) @@ -322,14 +245,29 @@ def total_episodes(self) -> int: return len(self.episode_indices) @property - def total_steps(self) -> int: + def total_steps(self) -> np.int64: """Total episodes steps in the Minari dataset.""" - return int(self._total_steps) + if self._total_steps is None: + if self.episode_indices is None: + self._total_steps = self.storage.total_steps + else: + self._total_steps = sum( + self.storage.apply( + lambda episode: episode["total_timesteps"], + episode_indices=self.episode_indices, + ) + ) + return np.int64(self._total_steps) @property def episode_indices(self) -> np.ndarray: """Indices of the available episodes to sample within the Minari dataset.""" return self._episode_indices + + @episode_indices.setter + def episode_indices(self, new_value: np.ndarray): + self._total_steps = None # invalidate cache + self._episode_indices = new_value @property def observation_space(self): @@ -368,3 +306,17 @@ def minari_version(self) -> str: def storage(self) -> MinariStorage: """MinariStorage managing access to disk.""" return self._data + + @property + def spec(self) -> MinariDatasetSpec: + return MinariDatasetSpec( + env_spec=self.env_spec, + total_episodes=self._episode_indices.size, + total_steps=self.total_steps, + dataset_id=self.id, + combined_datasets=self.combined_datasets, + observation_space=self.observation_space, + action_space=self.action_space, + data_path=str(self.storage.data_path), + minari_version=str(self.minari_version), + ) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index a6e9cbe7..6124fed2 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -3,7 +3,7 @@ import os import pathlib from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import gymnasium as gym from gymnasium.envs.registration import EnvSpec @@ -241,34 +241,37 @@ def update_episodes(self, episodes: Iterable[dict]): file.attrs.modify("total_episodes", total_episodes) file.attrs.modify("total_steps", total_steps) - def update_from_storage(self, storage: MinariStorage, copy: bool = False): + def update_from_storage(self, storage: MinariStorage): """Update the dataset using another MinariStorage. 
Args: storage (MinariStorage): the other MinariStorage from which the data will be taken - copy (bool): whether to copy the data or create a link. Default value is false. """ + if type(storage) != type(self): + # TODO: relax this constraint. In theory one can use MinariStorage API to update + raise ValueError(f"{type(self)} cannot update from {type(storage)}") + with h5py.File(self._file_path, "a", track_order=True) as file: last_episode_id = file.attrs["total_episodes"] assert type(last_episode_id) == np.int64 storage_total_episodes = storage.total_episodes - if copy: - for id in range(storage.total_episodes): - episode = storage.get_episodes([id]) - episode[0].pop("id") - episode[0].pop("total_timesteps") - self.update_episodes(episode) - else: - for id in range(storage_total_episodes): - file[f"episode_{last_episode_id + id}"] = h5py.ExternalLink(storage._file_path, f"/episode_{id}") - file[f"episode_{last_episode_id + id}"].attrs.modify( # TODO: check it doesn't modify original dataset - "id", last_episode_id + id - ) - file.attrs.modify("total_episodes", last_episode_id + storage_total_episodes) - total_steps = file.attrs["total_steps"] - assert type(total_steps) == np.int64 - file.attrs.modify("total_steps", total_steps + storage.total_steps) + for id in range(storage.total_episodes): + with h5py.File(storage._file_path, "r", track_order=True) as storage_file: + storage_file.copy( + storage_file[f"episode_{id}"], + file, + name=f"episode_{last_episode_id + id}", + ) + + file[f"episode_{last_episode_id + id}"].attrs.modify( + "id", last_episode_id + id + ) + + file.attrs.modify("total_episodes", last_episode_id + storage_total_episodes) + total_steps = file.attrs["total_steps"] + assert type(total_steps) == np.int64 + file.attrs.modify("total_steps", total_steps + storage.total_steps) storage_metadata = storage.metadata authors = [file.attrs.get("author"), storage_metadata.get("author")] @@ -371,4 +374,5 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): dshape = () if hasattr(data[0], "shape"): dshape = data[0].shape - episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape)) \ No newline at end of file + + episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape)) diff --git a/minari/utils.py index a88d965d..4f32449f 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -203,7 +203,7 @@ def __call__(self, observation: ObsType) -> ActType: return self.action_space.sample() def combine_datasets( - datasets_to_combine: List[MinariDataset], new_dataset_id: str, copy: bool = False + datasets_to_combine: List[MinariDataset], new_dataset_id: str ): """Combine a group of MinariDataset in to a single dataset with its own name id.
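# A minimal sketch of combining datasets with the simplified signature above: with
# the copy flag removed, episodes are always physically copied into the new
# dataset's storage instead of being referenced through HDF5 external links. The
# dataset ids are placeholders for datasets that already exist locally.
import minari
from minari import combine_datasets

parts = [minari.load_dataset(f"cartpole-test-{i}-v0") for i in range(2)]
combined = combine_datasets(parts, new_dataset_id="cartpole-combined-test-v0")
assert combined.total_episodes == sum(d.total_episodes for d in parts)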
@@ -213,7 +213,6 @@ def combine_datasets( Args: datasets_to_combine (list[MinariDataset]): list of datasets to be combined new_dataset_id (str): name id for the newly created dataset - copy (bool): whether to copy the data to a new dataset or to create a link Returns: combined_dataset (MinariDataset): the resulting MinariDataset @@ -229,7 +228,6 @@ def combine_datasets( datasets_minari_version_specifiers ) - new_dataset_path = get_dataset_path(new_dataset_id) new_dataset_path.mkdir() new_storage = MinariStorage.new( @@ -246,7 +244,7 @@ def combine_datasets( }) for dataset in datasets_to_combine: - new_storage.update_from_storage(dataset.storage, copy=copy) + new_storage.update_from_storage(dataset.storage) return MinariDataset(new_storage) diff --git a/tests/common.py b/tests/common.py index e4609fe3..eca597f0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -578,7 +578,6 @@ def create_dummy_dataset_with_collecter_env_helper( author="WillDudley", author_email="wdudley@farama.org", ) - env.close() assert dataset_id in minari.list_local_datasets() return dataset diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index ac484f8b..55764d0a 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -101,7 +101,7 @@ def test_data_collector_step_data_callback(): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) # check that the environment can be recovered from the dataset check_env_recovery_with_subset_spaces( diff --git a/tests/dataset/test_dataset_download.py b/tests/dataset/test_dataset_download.py index fe969bed..95db653f 100644 --- a/tests/dataset/test_dataset_download.py +++ b/tests/dataset/test_dataset_download.py @@ -56,7 +56,7 @@ def test_download_dataset_from_farama_server(dataset_id: str): dataset = minari.load_dataset(dataset_id) assert isinstance(dataset, MinariDataset) - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) minari.delete_dataset(dataset_id) local_datasets = minari.list_local_datasets() diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index fb5bf2a4..516f062a 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -63,9 +63,9 @@ def test_episode_data(space: gym.Space): @pytest.mark.parametrize( "dataset_id,env_id", [ - ("cartpole-test-v0", "CartPole-v1"), - ("dummy-dict-test-v0", "DummyDictEnv-v0"), - ("dummy-box-test-v0", "DummyBoxEnv-v0"), + # ("cartpole-test-v0", "CartPole-v1"), + # ("dummy-dict-test-v0", "DummyDictEnv-v0"), + # ("dummy-box-test-v0", "DummyBoxEnv-v0"), ("dummy-tuple-test-v0", "DummyTupleEnv-v0"), ("dummy-combo-test-v0", "DummyComboEnv-v0"), ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), @@ -98,14 +98,14 @@ def test_update_dataset_from_collector_env(dataset_id, env_id): env.reset() - dataset.update_dataset_from_collector_env(env) + env.add_to_dataset(dataset) assert isinstance(dataset, MinariDataset) assert dataset.total_episodes == num_episodes * 2 assert dataset.spec.total_episodes == num_episodes * 2 assert len(dataset.episode_indices) == num_episodes * 2 - check_data_integrity(dataset._data, dataset.episode_indices) + 
check_data_integrity(dataset.storage, dataset.episode_indices) check_env_recovery(env.env, dataset) env.close() @@ -152,7 +152,7 @@ def filter_by_index(episode: Any): assert len(filtered_dataset.episode_indices) == 7 check_data_integrity( - filtered_dataset._data, dataset.episode_indices + filtered_dataset.storage, dataset.episode_indices ) # checks that the underlying episodes are still present in the `MinariStorage` object check_env_recovery(env.env, filtered_dataset) @@ -168,7 +168,7 @@ def filter_by_index(episode: Any): env.reset() - filtered_dataset.update_dataset_from_collector_env(env) + env.add_to_dataset(filtered_dataset) assert isinstance(filtered_dataset, MinariDataset) assert filtered_dataset.total_episodes == 17 @@ -193,8 +193,8 @@ def filter_by_index(episode: Any): 18, 19, ) - assert filtered_dataset._data.total_episodes == 20 - assert dataset._data.total_episodes == 20 + assert filtered_dataset.storage.total_episodes == 20 + assert dataset.storage.total_episodes == 20 check_env_recovery(env.env, filtered_dataset) env.close() @@ -281,11 +281,12 @@ def filter_by_index(episode: Any): 28, 29, ) - assert filtered_dataset._data.total_episodes == 30 - assert dataset._data.total_episodes == 30 + assert filtered_dataset.storage.total_episodes == 30 + assert dataset.storage.total_episodes == 30 check_env_recovery(env, filtered_dataset) check_load_and_delete_dataset(dataset_id) + env.close() @pytest.mark.parametrize( @@ -328,6 +329,8 @@ def filter_by_index(episode: Any): with pytest.raises(ValueError): episodes = filtered_dataset.sample_episodes(8) + env.close() + @pytest.mark.parametrize( "dataset_id,env_id", @@ -353,6 +356,7 @@ def test_iterate_episodes(dataset_id, env_id): dataset = create_dummy_dataset_with_collecter_env_helper( dataset_id, env, num_episodes=num_episodes ) + env.close() episodes = list(dataset.iterate_episodes([1, 3, 5])) @@ -454,10 +458,10 @@ def test_update_dataset_from_buffer(dataset_id, env_id): assert dataset.spec.total_episodes == num_episodes * 2 assert len(dataset.episode_indices) == num_episodes * 2 - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) check_env_recovery(env, dataset) - env.close() + collector_env.close() check_load_and_delete_dataset(dataset_id) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index eaa67025..63225528 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -119,10 +119,10 @@ def test_combine_datasets(): if "cartpole-combined-test-v0" in local_datasets: minari.delete_dataset("cartpole-combined-test-v0") - # testing without creating a copy combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) + assert test_datasets[1][0].id == 0 assert isinstance(combined_dataset, MinariDataset) assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids assert combined_dataset.spec.total_episodes == num_datasets * num_episodes @@ -131,18 +131,6 @@ def test_combine_datasets(): ) _check_env_recovery(gym.make("CartPole-v1"), combined_dataset) - _check_load_and_delete_dataset("cartpole-combined-test-v0") - - # testing with copy - combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0", copy=True - ) - assert isinstance(combined_dataset, MinariDataset) - assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids - assert combined_dataset.spec.total_episodes == num_datasets * 
num_episodes - assert combined_dataset.spec.total_steps == sum(d.spec.total_steps for d in test_datasets) - _check_env_recovery(gym.make("CartPole-v1"), combined_dataset) - # deleting test datasets for dataset_id in test_datasets_ids: minari.delete_dataset(dataset_id) @@ -171,32 +159,17 @@ def test_combine_datasets(): minari.load_dataset(dataset_id) for dataset_id in test_datasets_ids ] - # testing without creating a copy combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) assert combined_dataset.spec.env_spec.max_episode_steps is None _check_load_and_delete_dataset("cartpole-combined-test-v0") - # testing with copy - combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0", copy=True - ) - assert combined_dataset.spec.env_spec.max_episode_steps is None - _check_load_and_delete_dataset("cartpole-combined-test-v0") - # Check that we get max(max_episode_steps) when there is no max_episode_steps=None test_datasets.pop() - # testing without creating a copy - combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0" - ) - assert combined_dataset.spec.env_spec.max_episode_steps == 10 - _check_load_and_delete_dataset("cartpole-combined-test-v0") - # testing with copy combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0", copy=True + test_datasets, new_dataset_id="cartpole-combined-test-v0" ) assert combined_dataset.spec.env_spec.max_episode_steps == 10 _check_load_and_delete_dataset("cartpole-combined-test-v0") diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 17ce6e3d..57623713 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -73,7 +73,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset) From 0b41d4729524c851369f4004852197dc2b651e33 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 19 Sep 2023 15:11:29 -0400 Subject: [PATCH 14/19] fix build 3.9 & some linter --- .../callbacks/episode_metadata.py | 5 +- minari/data_collector/callbacks/step_data.py | 2 +- minari/data_collector/data_collector.py | 31 ++--- minari/dataset/episode_data.py | 1 + minari/dataset/minari_dataset.py | 25 ++-- minari/dataset/minari_storage.py | 125 +++++++++++------- minari/storage/local.py | 5 +- minari/utils.py | 38 +++--- tests/common.py | 5 +- tests/dataset/test_minari_dataset.py | 36 ++--- tests/dataset/test_minari_storage.py | 62 ++++----- tests/utils/test_dataset_combine.py | 1 + tests/utils/test_dataset_creation.py | 4 +- 13 files changed, 182 insertions(+), 158 deletions(-) diff --git a/minari/data_collector/callbacks/episode_metadata.py b/minari/data_collector/callbacks/episode_metadata.py index e4ed01cf..4b72abcb 100644 --- a/minari/data_collector/callbacks/episode_metadata.py +++ b/minari/data_collector/callbacks/episode_metadata.py @@ -1,4 +1,5 @@ from typing import Dict + import numpy as np @@ -18,12 +19,12 @@ def __call__(self, episode: Dict): Override this method to add custom attribute metadata to the episode group. 
Args: - eps_group (dict): the dict that contains an episode's data + episode (dict): the dict that contains an episode's data """ return { "rewards_sum": np.sum(episode["rewards"]), "rewards_mean": np.mean(episode["rewards"]), "rewards_std": np.std(episode["rewards"]), "rewards_max": np.max(episode["rewards"]), - "rewards_min": np.min(episode["rewards"]) + "rewards_min": np.min(episode["rewards"]), } diff --git a/minari/data_collector/callbacks/step_data.py b/minari/data_collector/callbacks/step_data.py index 0c73db94..bcc81eb5 100644 --- a/minari/data_collector/callbacks/step_data.py +++ b/minari/data_collector/callbacks/step_data.py @@ -19,7 +19,7 @@ class StepData(TypedDict): "rewards", "truncations", "terminations", - "infos" + "infos", } diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index be53dee1..35e9d70d 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, SupportsFloat, Type, Union import gymnasium as gym -from gymnasium.core import ActType, ObsType import numpy as np +from gymnasium.core import ActType, ObsType from minari.data_collector.callbacks import ( STEP_DATA_KEYS, @@ -105,7 +105,7 @@ def __init__( self._tmp_dir.name, observation_space=observation_space, action_space=action_space, - env_spec=self.env.spec + env_spec=self.env.spec, ) if observation_space is None: @@ -194,8 +194,8 @@ def step( self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) if ( - self.max_buffer_steps is not None - and self._step_id != 0 + self.max_buffer_steps is not None + and self._step_id != 0 and self._step_id % self.max_buffer_steps == 0 ): self._storage.update_episodes(self._buffer) @@ -205,8 +205,8 @@ def step( eps_buff = {"id": self._episode_id} previous_data = { "observations": step_data["observations"], - "infos": step_data["infos"] - } + "infos": step_data["infos"], + } eps_buff = self._add_to_episode_buffer(eps_buff, previous_data) self._buffer.append(eps_buff) @@ -227,11 +227,8 @@ def reset( step_data.keys() ), "One or more required keys is missing from 'step-data'" - self._validate_buffer() - episode_buffer = { - "seed": seed if seed else str(None), - "id": self._episode_id - } + self._validate_buffer() + episode_buffer = {"seed": seed if seed else str(None), "id": self._episode_id} episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data) self._buffer.append(episode_buffer) return obs, info @@ -243,7 +240,7 @@ def _validate_buffer(self): self._episode_id -= 1 elif not self._buffer[-1]["terminations"][-1]: self._buffer[-1]["truncations"][-1] = True - + def add_to_dataset(self, dataset: MinariDataset): """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). 
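# A sketch of plugging in custom per-episode metadata now that the callback returns
# a plain dict (see the EpisodeMetadataCallback change earlier in this patch). The
# subclass name and the extra "episode_length" key are illustrative; the constructor
# keyword names follow the DataCollectorV0 docstrings.
from typing import Dict

import gymnasium as gym

from minari import DataCollectorV0
from minari.data_collector.callbacks import EpisodeMetadataCallback


class LengthMetadataCallback(EpisodeMetadataCallback):
    def __call__(self, episode: Dict) -> Dict:
        metadata = super().__call__(episode)  # reward statistics from the base class
        metadata["episode_length"] = len(episode["rewards"])
        return metadata


env = DataCollectorV0(
    gym.make("CartPole-v1"),
    episode_metadata_callback=LengthMetadataCallback,
    max_buffer_steps=1000,  # flush the in-memory buffer to storage every 1000 steps
)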
@@ -259,14 +256,14 @@ def add_to_dataset(self, dataset: MinariDataset): if dataset.episode_indices is not None: new_ids = first_id + np.arange(self._storage.total_episodes) dataset.episode_indices = np.append(dataset.episode_indices, new_ids) - + self._episode_id = -1 self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) self._storage = MinariStorage.new( self._tmp_dir.name, observation_space=self._storage.observation_space, action_space=self._storage.action_space, - env_spec=self.env.spec + env_spec=self.env.spec, ) def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): @@ -285,7 +282,7 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): ), "'observation_space' is not allowed as an optional key." assert ( "action_space" not in dataset_metadata.keys() - ), "'action_space' is not allowed as an optional key." + ), "'action_space' is not allowed as an optional key." assert ( "env_spec" not in dataset_metadata.keys() ), "'env_spec' is not allowed as an optional key." @@ -300,14 +297,14 @@ def save_to_disk(self, path: str, dataset_metadata: Dict[str, Any] = {}): os.path.join(self._storage.data_path, file), os.path.join(path, file), ) - + self._episode_id = -1 self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) self._storage = MinariStorage.new( self._tmp_dir.name, observation_space=self._storage.observation_space, action_space=self._storage.action_space, - env_spec=self.env.spec + env_spec=self.env.spec, ) def close(self): diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index db2b02e3..2922dcc9 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import Optional + import numpy as np diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index a518ff85..a8a589ae 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -1,7 +1,7 @@ from __future__ import annotations -import importlib.metadata -import json +import importlib.metadata +import os import re from dataclasses import dataclass, field from typing import Callable, Iterable, Iterator, List, Optional, Union @@ -13,9 +13,8 @@ from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version -from minari.dataset.minari_storage import MinariStorage, PathLike from minari.dataset.episode_data import EpisodeData - +from minari.dataset.minari_storage import MinariStorage, PathLike # Use importlib due to circular import when: "from minari import __version__" @@ -88,13 +87,13 @@ def __init__( """ if isinstance(data, MinariStorage): self._data = data - elif isinstance(data, PathLike): + elif isinstance(data, (str, os.PathLike)): self._data = MinariStorage(data) else: raise ValueError(f"Unrecognized type {type(data)} for data") if episode_indices is None: - episode_indices = np.arange(self._data.total_episodes) + episode_indices = np.arange(self._data.total_episodes) self._episode_indices: np.ndarray = episode_indices self._total_steps = None @@ -227,7 +226,9 @@ def update_dataset_from_buffer(self, buffer: List[dict]): """ first_id = self.storage.total_episodes self.storage.update_episodes(buffer) - self.episode_indices = np.append(self.episode_indices, first_id + np.arange(len(buffer))) + self.episode_indices = np.append( + self.episode_indices, first_id + np.arange(len(buffer)) + ) def __iter__(self): return self.iterate_episodes() @@ -239,7 +240,7 @@ def 
__getitem__(self, idx: int) -> EpisodeData: def __len__(self) -> int: return self.total_episodes - + @property def total_episodes(self) -> int: return len(self.episode_indices) @@ -263,7 +264,7 @@ def total_steps(self) -> np.int64: def episode_indices(self) -> np.ndarray: """Indices of the available episodes to sample within the Minari dataset.""" return self._episode_indices - + @episode_indices.setter def episode_indices(self, new_value: np.ndarray): self._total_steps = None # invalidate cache @@ -301,12 +302,12 @@ def id(self) -> str: def minari_version(self) -> str: """Version of Minari the dataset is compatible with.""" return self._minari_version - + @property def storage(self) -> MinariStorage: - """MinariStorage managing access to disk.""" + """Minari storage managing access to disk.""" return self._data - + @property def spec(self) -> MinariDatasetSpec: return MinariDatasetSpec( diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 6124fed2..31e02441 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -3,12 +3,12 @@ import os import pathlib from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import gymnasium as gym -from gymnasium.envs.registration import EnvSpec import h5py import numpy as np +from gymnasium.envs.registration import EnvSpec from minari.serialization import deserialize_space, serialize_space @@ -21,6 +21,7 @@ class MinariStorage: def __init__(self, data_path: PathLike): """Initialize a MinariStorage with an existing data path. + To create a new dataset, use the class method `new`. Args: @@ -46,40 +47,39 @@ def new( data_path: PathLike, observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, - env_spec: Optional[EnvSpec] = None + env_spec: Optional[EnvSpec] = None, ) -> MinariStorage: - """Class method to create a new data storage. + """Class method to create a new data storage. Args: - data_path (str or Path): directory where the data will be stored. + data_path (str or Path): directory where the data will be stored. observation_space (gymnasium.Space, optional): Gymnasium observation space of the dataset. action_space (gymnasium.Space, optional): Gymnasium action space of the dataset. env_spec (EnvSpec, optional): Gymnasium EnvSpec of the environment that generates the dataset. Returns: - A new MinariStorage object. - + A new MinariStorage object. + Raises: ValueError: if you don't specify the env_spec, you need to specify both observation_space and action_space. """ if env_spec is None and (observation_space is None or action_space is None): - raise ValueError("Since env_spec is not specified, you need to specify both action space and observation space!") + raise ValueError( + "Since env_spec is not specified, you need to specify both action space and observation space!" 
+ ) data_path = pathlib.Path(data_path) data_path.mkdir(exist_ok=True) data_path.joinpath("main_data.hdf5").touch(exist_ok=False) - + obj = cls(data_path) - metadata: Dict[str, Any] = { - "total_episodes": 0, - "total_steps": 0 - } + metadata: Dict[str, Any] = {"total_episodes": 0, "total_steps": 0} if observation_space is not None: metadata["observation_space"] = serialize_space(observation_space) obj._observation_space = observation_space if action_space is not None: metadata["action_space"] = serialize_space(action_space) obj._action_space = action_space - if env_spec is not None: + if env_spec is not None: metadata["env_spec"] = env_spec.to_json() with h5py.File(obj._file_path, "a") as file: @@ -93,36 +93,42 @@ def metadata(self) -> Dict: with h5py.File(self._file_path, "r") as file: metadata.update(file.attrs) - metadata["observation_space"] = self.observation_space + metadata["observation_space"] = self.observation_space metadata["action_space"] = self.action_space return metadata - + def update_metadata(self, metadata: Dict): """Update the metadata adding/modifying some keys. - + Args: metadata (dict): dictionary of keys-values to add to the metadata. """ - forbidden_keys = {"observation_space", "action_space", "env_spec"}.intersection(metadata.keys()) + forbidden_keys = {"observation_space", "action_space", "env_spec"}.intersection( + metadata.keys() + ) if forbidden_keys: - raise ValueError(f"You are not allowed to update values for {', '.join(forbidden_keys)}") + raise ValueError( + f"You are not allowed to update values for {', '.join(forbidden_keys)}" + ) with h5py.File(self._file_path, "a") as file: file.attrs.update(metadata) - def update_episode_metadata(self, metadatas: Iterable[Dict], episode_indices: Optional[Iterable] = None): + def update_episode_metadata( + self, metadatas: Iterable[Dict], episode_indices: Optional[Iterable] = None + ): """Update the metadata of episodes. Args: metadatas (Iterable[Dict]): metadatas, one for each episode. - episode_indices (Iterable, optional): episode indices to update. + episode_indices (Iterable, optional): episode indices to update. If not specified, all the episodes are considered. - + Warning: In case metadatas and episode_indices have different lengths, the longest is truncated silently. 
""" if episode_indices is None: episode_indices = range(self.total_episodes) - + with h5py.File(self._file_path, "a") as file: for metadata, episode_id in zip(metadatas, episode_indices): ep_group = file[f"episode_{episode_id}"] @@ -144,7 +150,7 @@ def apply( """ if episode_indices is None: episode_indices = range(self.total_episodes) - + ep_dicts = self.get_episodes(episode_indices) return map(function, ep_dicts) @@ -191,13 +197,17 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] assert isinstance(ep_group, h5py.Group) - + ep_dict = { "id": ep_group.attrs.get("id"), "total_timesteps": ep_group.attrs.get("total_steps"), "seed": ep_group.attrs.get("seed"), - "observations": self._decode_space(ep_group["observations"], self.observation_space), - "actions": self._decode_space(ep_group["actions"], self.action_space), + "observations": self._decode_space( + ep_group["observations"], self.observation_space + ), + "actions": self._decode_space( + ep_group["actions"], self.action_space + ), } for key in {"rewards", "terminations", "truncations"}: group_value = ep_group[key] @@ -214,18 +224,20 @@ def update_episodes(self, episodes: Iterable[dict]): Args: episodes (Iterable[dict]): list of episodes buffer. They must contain the keys specified in EpsiodeData dataclass, except for `id` which is optional. - If `id` is specified and exists, the new data is appended to the one in the storage. + If `id` is specified and exists, the new data is appended to the one in the storage. """ additional_steps = 0 with h5py.File(self._file_path, "a", track_order=True) as file: for eps_buff in episodes: total_episodes = len(file.keys()) episode_id = eps_buff.pop("id", total_episodes) - assert episode_id <= total_episodes, "Invalid episode id; ids must be sequential." + assert ( + episode_id <= total_episodes + ), "Invalid episode id; ids must be sequential." episode_group = _get_from_h5py(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id if "seed" in eps_buff.keys(): - assert not "seed" in episode_group.attrs.keys() + assert "seed" not in episode_group.attrs.keys() episode_group.attrs["seed"] = eps_buff.pop("seed") total_steps = len(eps_buff["rewards"]) episode_group.attrs["total_steps"] = total_steps @@ -248,7 +260,7 @@ def update_from_storage(self, storage: MinariStorage): storage (MinariStorage): the other MinariStorage from which the data will be taken """ if type(storage) != type(self): - # TODO: relax this constraint. In theory one can use MinariStorage API to udpate + # TODO: relax this constraint. 
In theory one can use MinariStorage API to update raise ValueError(f"{type(self)} cannot update from {type(storage)}") with h5py.File(self._file_path, "a", track_order=True) as file: @@ -257,27 +269,38 @@ def update_from_storage(self, storage: MinariStorage): storage_total_episodes = storage.total_episodes for id in range(storage.total_episodes): - with h5py.File(storage._file_path, "r", track_order=True) as storage_file: + with h5py.File( + storage._file_path, "r", track_order=True + ) as storage_file: storage_file.copy( - storage_file[f"episode_{id}"], - file, - name=f"episode_{last_episode_id + id}", - ) + storage_file[f"episode_{id}"], + file, + name=f"episode_{last_episode_id + id}", + ) file[f"episode_{last_episode_id + id}"].attrs.modify( "id", last_episode_id + id ) - file.attrs.modify("total_episodes", last_episode_id + storage_total_episodes) + file.attrs.modify( + "total_episodes", last_episode_id + storage_total_episodes + ) total_steps = file.attrs["total_steps"] assert type(total_steps) == np.int64 file.attrs.modify("total_steps", total_steps + storage.total_steps) storage_metadata = storage.metadata authors = [file.attrs.get("author"), storage_metadata.get("author")] - file.attrs.modify("author", '; '.join([aut for aut in authors if aut is not None])) - emails = [file.attrs.get("author_email"), storage_metadata.get("author_email")] - file.attrs.modify("author_email", '; '.join([e for e in emails if e is not None])) + file.attrs.modify( + "author", "; ".join([aut for aut in authors if aut is not None]) + ) + emails = [ + file.attrs.get("author_email"), + storage_metadata.get("author_email"), + ] + file.attrs.modify( + "author_email", "; ".join([e for e in emails if e is not None]) + ) @property def data_path(self) -> PathLike: @@ -299,7 +322,7 @@ def total_steps(self) -> np.int64: total_steps = file.attrs["total_steps"] assert type(total_steps) == np.int64 return total_steps - + @property def observation_space(self) -> gym.Space: """Observation Space of the dataset.""" @@ -333,6 +356,7 @@ def action_space(self) -> gym.Space: return self._action_space + def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: subgroup = group.get(name) @@ -342,6 +366,7 @@ def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: return subgroup + def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): for key, data in episode_buffer.items(): if isinstance(data, dict): @@ -354,19 +379,19 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): } episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) - elif all([isinstance(entry, OrderedDict) for entry in data]): # list of OrderedDict - dict_data = { - key: [entry[key] for entry in data] for key in data[0].keys() - } + elif all( + [isinstance(entry, OrderedDict) for entry in data] + ): # list of OrderedDict + dict_data = {key: [entry[key] for entry in data] for key in data[0].keys()} episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) - + # leaf data elif key in episode_group: dataset = episode_group[key] assert isinstance(dataset, h5py.Dataset) dataset.resize((dataset.shape[0] + len(data), *dataset.shape[1:])) - dataset[-len(data):] = data + dataset[-len(data) :] = data else: dtype = None if all(map(lambda elem: isinstance(elem, str), data)): @@ -374,5 +399,7 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): dshape = () 
if hasattr(data[0], "shape"): dshape = data[0].shape - - episode_group.create_dataset(key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape)) + + episode_group.create_dataset( + key, data=data, dtype=dtype, chunks=True, maxshape=(None, *dshape) + ) diff --git a/minari/storage/local.py b/minari/storage/local.py index 54213ac3..1f12ca66 100644 --- a/minari/storage/local.py +++ b/minari/storage/local.py @@ -77,10 +77,7 @@ def list_local_datasets( env_name, dataset_name, version = parse_dataset_id(dst_id) dataset = f"{env_name}-{dataset_name}" if latest_version: - if ( - dataset not in local_datasets - or version > local_datasets[dataset][0] - ): + if dataset not in local_datasets or version > local_datasets[dataset][0]: local_datasets[dataset] = (version, metadata) else: local_datasets[dst_id] = metadata diff --git a/minari/utils.py b/minari/utils.py index 4f32449f..78502562 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -202,9 +202,7 @@ def __call__(self, observation: ObsType) -> ActType: return self.action_space.sample() -def combine_datasets( - datasets_to_combine: List[MinariDataset], new_dataset_id: str -): +def combine_datasets(datasets_to_combine: List[MinariDataset], new_dataset_id: str): """Combine a group of MinariDataset in to a single dataset with its own name id. The new dataset will contain a metadata attribute `combined_datasets` containing a list @@ -231,20 +229,21 @@ def combine_datasets( new_dataset_path = get_dataset_path(new_dataset_id) new_dataset_path.mkdir() new_storage = MinariStorage.new( - new_dataset_path.joinpath("data"), - env_spec=combined_dataset_env_spec + new_dataset_path.joinpath("data"), env_spec=combined_dataset_env_spec ) - new_storage.update_metadata({ - "dataset_id": new_dataset_id, - "combined_datasets": [ - dataset.spec.dataset_id for dataset in datasets_to_combine - ], - "minari_version": str(minari_version_specifier) - }) + new_storage.update_metadata( + { + "dataset_id": new_dataset_id, + "combined_datasets": [ + dataset.spec.dataset_id for dataset in datasets_to_combine + ], + "minari_version": str(minari_version_specifier), + } + ) for dataset in datasets_to_combine: - new_storage.update_from_storage(dataset.storage) + new_storage.update_from_storage(dataset.storage) return MinariDataset(new_storage) @@ -407,15 +406,15 @@ def create_dataset_from_buffers( dataset_path = os.path.join(dataset_path, "data") storage = MinariStorage.new( - dataset_path, + dataset_path, observation_space=observation_space, action_space=action_space, - env_spec=env.spec + env_spec=env.spec, ) metadata: Dict[str, Any] = { "dataset_id": dataset_id, - "minari_version": minari_version + "minari_version": minari_version, } if algorithm_name is not None: metadata["algorithm_name"] = algorithm_name @@ -424,7 +423,7 @@ def create_dataset_from_buffers( if author_email is not None: metadata["author_email"] = author_email if code_permalink is not None: - metadata["code_permalink"] = code_permalink + metadata["code_permalink"] = code_permalink if expert_policy is not None or ref_max_score is not None: env = copy.deepcopy(env) if ref_min_score is None: @@ -528,7 +527,7 @@ def create_dataset_from_collector_env( raise ValueError( f"A Minari dataset with ID {dataset_id} already exists and it cannot be overridden. Please use a different dataset name or version." 
) - + dataset_path = os.path.join(dataset_path, "data") os.makedirs(dataset_path) dataset_metadata: Dict[str, Any] = { @@ -542,7 +541,7 @@ def create_dataset_from_collector_env( if author_email is not None: dataset_metadata["author_email"] = author_email if code_permalink is not None: - dataset_metadata["code_permalink"] = code_permalink + dataset_metadata["code_permalink"] = code_permalink if expert_policy is not None or ref_max_score is not None: env = copy.deepcopy(collector_env.env) if ref_min_score is None: @@ -561,6 +560,7 @@ def create_dataset_from_collector_env( collector_env.save_to_disk(dataset_path, dataset_metadata) return MinariDataset(dataset_path) + def get_normalized_score( dataset: MinariDataset, returns: Union[float, np.float32] ) -> Union[float, np.float32]: diff --git a/tests/common.py b/tests/common.py index eca597f0..9a8856d6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -473,9 +473,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): observation_space, episode["total_timesteps"] + 1, ) - _check_space_elem( - episode["actions"], action_space, episode["total_timesteps"] - ) + _check_space_elem(episode["actions"], action_space, episode["total_timesteps"]) for i in range(episode["total_timesteps"] + 1): obs = _reconstuct_obs_or_action_at_index_recursive( @@ -582,6 +580,7 @@ def create_dummy_dataset_with_collecter_env_helper( assert dataset_id in minari.list_local_datasets() return dataset + def check_episode_data_integrity( episode_data_list: List[EpisodeData], observation_space: gym.spaces.Space, diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 516f062a..c31d311a 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -5,12 +5,11 @@ from typing import Any import gymnasium as gym -from gymnasium.envs.registration import EnvSpec import numpy as np import pytest +from gymnasium.envs.registration import EnvSpec import minari -from minari import __version__ from minari import DataCollectorV0, MinariDataset from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import MinariStorage @@ -466,19 +465,20 @@ def test_update_dataset_from_buffer(dataset_id, env_id): def test_missing_env_module(): - data_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets", "dummy-test-v0") - class FakeEnvSpec(EnvSpec): - def to_json(self) -> str: - return r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""" - - storage = MinariStorage.new( - data_path, - env_spec=FakeEnvSpec("DummyEnv-v0"), - ) - - with pytest.raises( - ModuleNotFoundError, match="No module named 'dummymodule'" - ): - MinariDataset(storage.data_path) - - shutil.rmtree(data_path) + data_path = os.path.join( + os.path.expanduser("~"), ".minari", "datasets", "dummy-test-v0" + ) + + class FakeEnvSpec(EnvSpec): + def to_json(self) -> str: + return r"""{"id": "DummyEnv-v0", "entry_point": "dummymodule:dummyenv", "reward_threshold": null, "nondeterministic": false, "max_episode_steps": 300, "order_enforce": true, "disable_env_checker": false, "apply_api_compatibility": false, "additional_wrappers": []}""" + + storage = MinariStorage.new( + data_path, + env_spec=FakeEnvSpec("DummyEnv-v0"), + ) + + with pytest.raises(ModuleNotFoundError, match="No module named 
'dummymodule'"): + MinariDataset(storage.data_path) + + shutil.rmtree(data_path) diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 311d0aef..5a1bdecc 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,8 +1,10 @@ import tempfile -from minari.dataset.minari_storage import MinariStorage -from gymnasium import spaces -import pytest + import numpy as np +import pytest +from gymnasium import spaces + +from minari.dataset.minari_storage import MinariStorage @pytest.fixture(autouse=True) @@ -12,7 +14,9 @@ def tmp_dir(): tmp_dir.cleanup() -def _generate_episode_dict(observation_space: spaces.Space, action_space: spaces.Space, length=25): +def _generate_episode_dict( + observation_space: spaces.Space, action_space: spaces.Space, length=25 +): terminations = np.zeros(length, dtype=np.bool_) truncations = np.zeros(length, dtype=np.bool_) terminated = np.random.randint(2, dtype=np.bool_) @@ -24,15 +28,16 @@ def _generate_episode_dict(observation_space: spaces.Space, action_space: spaces "actions": [action_space.sample() for _ in range(length)], "rewards": np.random.randn(length), "terminations": terminations, - "truncations": truncations + "truncations": truncations, } + def test_non_existing_data(tmp_dir): with pytest.raises(ValueError, match="The data path foo doesn't exist"): - MinariStorage("foo") - + MinariStorage("foo") + with pytest.raises(ValueError, match="No data found in data path"): - MinariStorage(tmp_dir) + MinariStorage(tmp_dir) def test_metadata(tmp_dir): @@ -41,28 +46,23 @@ def test_metadata(tmp_dir): storage = MinariStorage.new( data_path=tmp_dir, observation_space=observation_space, - action_space=action_space + action_space=action_space, ) assert storage.data_path == tmp_dir - extra_metadata = { - "float": 3.2, - "string": "test-value", - "int": 2, - "bool": True - } + extra_metadata = {"float": 3.2, "string": "test-value", "int": 2, "bool": True} storage.update_metadata(extra_metadata) storage_metadata = storage.metadata assert storage_metadata.keys() == { - 'action_space', - 'bool', - 'float', - 'int', - 'observation_space', - 'string', - 'total_episodes', - 'total_steps' + "action_space", + "bool", + "float", + "int", + "observation_space", + "string", + "total_episodes", + "total_steps", } for key, value in extra_metadata.items(): @@ -78,13 +78,15 @@ def test_add_episodes(tmp_dir): n_episodes = 10 steps_per_episode = 25 episodes = [ - _generate_episode_dict(observation_space, action_space, length=steps_per_episode) + _generate_episode_dict( + observation_space, action_space, length=steps_per_episode + ) for _ in range(n_episodes) ] storage = MinariStorage.new( data_path=tmp_dir, observation_space=observation_space, - action_space=action_space + action_space=action_space, ) storage.update_episodes(episodes) del storage @@ -95,7 +97,7 @@ def test_add_episodes(tmp_dir): for i, ep in enumerate(episodes): storage_ep = storage.get_episodes([i])[0] - + assert np.all(ep["observations"] == storage_ep["observations"]) assert np.all(ep["actions"] == storage_ep["actions"]) assert np.all(ep["rewards"] == storage_ep["rewards"]) @@ -140,20 +142,18 @@ def test_apply(tmp_dir): storage = MinariStorage.new( data_path=tmp_dir, observation_space=observation_space, - action_space=action_space + action_space=action_space, ) storage.update_episodes(episodes) def f(ep): return ep["actions"].sum() - + episode_indices = [1, 3, 5] outs = storage.apply(f, episode_indices=episode_indices) assert len(episode_indices) 
== len(list(outs)) for i, result in zip(episode_indices, outs): assert np.array(episodes[i]["actions"]).sum() == result - - def test_episode_metadata(tmp_dir): @@ -167,7 +167,7 @@ def test_episode_metadata(tmp_dir): storage = MinariStorage.new( data_path=tmp_dir, observation_space=observation_space, - action_space=action_space + action_space=action_space, ) storage.update_episodes(episodes) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index 63225528..1be5dc3c 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -71,6 +71,7 @@ def _generate_dataset_with_collector_env( if max_episode_steps is None: # Force None max_episode_steps env_spec = gym.make("CartPole-v1").spec + assert env_spec is not None env_spec.max_episode_steps = None env = env_spec.make() else: diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 57623713..479d0df4 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -166,7 +166,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) check_env_recovery(env, dataset) env.close() @@ -280,7 +280,7 @@ def test_generate_dataset_with_space_subset_external_buffer(): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) check_env_recovery_with_subset_spaces( env, dataset, action_space_subset, observation_space_subset ) From 037b08be8522c514ebb1e30f6aadece5c0a1fd5e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 2 Oct 2023 00:04:03 -0400 Subject: [PATCH 15/19] restore tests --- tests/dataset/test_minari_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index c31d311a..5c901a40 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -62,9 +62,9 @@ def test_episode_data(space: gym.Space): @pytest.mark.parametrize( "dataset_id,env_id", [ - # ("cartpole-test-v0", "CartPole-v1"), - # ("dummy-dict-test-v0", "DummyDictEnv-v0"), - # ("dummy-box-test-v0", "DummyBoxEnv-v0"), + ("cartpole-test-v0", "CartPole-v1"), + ("dummy-dict-test-v0", "DummyDictEnv-v0"), + ("dummy-box-test-v0", "DummyBoxEnv-v0"), ("dummy-tuple-test-v0", "DummyTupleEnv-v0"), ("dummy-combo-test-v0", "DummyComboEnv-v0"), ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), From 31b126aa7776a3f8014f0328982f8138f73003b6 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 2 Oct 2023 23:34:38 -0400 Subject: [PATCH 16/19] refactor get_normalized_score --- minari/dataset/minari_storage.py | 2 +- minari/utils.py | 16 ++++-------- tests/common.py | 5 ++-- tests/utils/test_get_normalized_score.py | 32 ++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 31e02441..3ab16163 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -87,7 +87,7 @@ def new( return obj @property - def metadata(self) -> Dict: + def metadata(self) -> Dict[str, Any]: 
"""Metadata of the dataset.""" metadata = {} with h5py.File(self._file_path, "r") as file: diff --git a/minari/utils.py b/minari/utils.py index 78502562..09b20681 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -7,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import gymnasium as gym -import h5py import numpy as np import portion as P from gymnasium.core import ActType, ObsType @@ -561,9 +560,7 @@ def create_dataset_from_collector_env( return MinariDataset(dataset_path) -def get_normalized_score( - dataset: MinariDataset, returns: Union[float, np.float32] -) -> Union[float, np.float32]: +def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray: r"""Normalize undiscounted return of an episode. This function was originally provided in the `D4RL repository `_. @@ -579,20 +576,17 @@ def get_normalized_score( Args: dataset (MinariDataset): the MinariDataset with respect to which normalize the score. Must contain the reference score attributes `ref_min_score` and `ref_max_score`. - returns (float | np.float32): a single value or array of episode undiscounted returns to normalize. + returns (np.ndarray): a single value or array of episode undiscounted returns to normalize. Returns: normalized_scores """ - with h5py.File(dataset.spec.data_path, "r") as f: - ref_min_score = f.attrs.get("ref_min_score", default=None) - ref_max_score = f.attrs.get("ref_max_score", default=None) + ref_min_score = dataset.storage.metadata.get("ref_min_score") + ref_max_score = dataset.storage.metadata.get("ref_max_score") + if ref_min_score is None or ref_max_score is None: raise ValueError( f"Reference score not provided for dataset {dataset.spec.dataset_id}. Can't compute the normalized score." ) - assert isinstance(ref_min_score, float) - assert isinstance(ref_max_score, float) - return (returns - ref_min_score) / (ref_max_score - ref_min_score) diff --git a/tests/common.py b/tests/common.py index 9a8856d6..9a595613 100644 --- a/tests/common.py +++ b/tests/common.py @@ -549,7 +549,7 @@ def check_load_and_delete_dataset(dataset_id: str): def create_dummy_dataset_with_collecter_env_helper( - dataset_id: str, env: DataCollectorV0, num_episodes: int = 10 + dataset_id: str, env: DataCollectorV0, num_episodes: int = 10, **kwargs ): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: @@ -572,9 +572,10 @@ def create_dummy_dataset_with_collecter_env_helper( dataset_id=dataset_id, collector_env=env, algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + code_permalink="https://github.com/Farama-Foundation/Minari/blob/main/tests/common.py", author="WillDudley", author_email="wdudley@farama.org", + **kwargs, ) assert dataset_id in minari.list_local_datasets() diff --git a/tests/utils/test_get_normalized_score.py b/tests/utils/test_get_normalized_score.py index e69de29b..751c9029 100644 --- a/tests/utils/test_get_normalized_score.py +++ b/tests/utils/test_get_normalized_score.py @@ -0,0 +1,32 @@ +import gymnasium as gym +import numpy as np + +import minari +from minari import get_normalized_score +from minari.data_collector.data_collector import DataCollectorV0 +from tests.common import create_dummy_dataset_with_collecter_env_helper + + +def test_ref_score(): + local_datasets = minari.list_local_datasets() + if "cartpole-test-v0" in local_datasets: + minari.delete_dataset("cartpole-test-v0") + + env = 
gym.make("CartPole-v1") + + env = DataCollectorV0(env) + num_episodes = 10 + + ref_min_score, ref_max_score = -1, 100 + dataset = create_dummy_dataset_with_collecter_env_helper( + "cartpole-test-v0", + env, + num_episodes=num_episodes, + ref_min_score=ref_min_score, + ref_max_score=ref_max_score, + ) + + scores = np.linspace(ref_min_score, ref_max_score, num=10) + norm_scores = np.linspace(0, 1, num=10) + + assert np.allclose(get_normalized_score(dataset, scores), norm_scores) From 47c4a68c387633bb328a6bee4753cd9d9f1bed9a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 12 Oct 2023 18:10:55 -0400 Subject: [PATCH 17/19] fix test after merge --- tests/utils/test_dataset_combine.py | 49 ---------------------------- tests/utils/test_dataset_creation.py | 28 +++++++--------- 2 files changed, 12 insertions(+), 65 deletions(-) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index 8987385f..1be5dc3c 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -1,7 +1,6 @@ from typing import Optional import gymnasium as gym -import h5py import pytest from gymnasium.utils.env_checker import data_equivalence from packaging.specifiers import SpecifierSet @@ -204,51 +203,3 @@ def test_combine_minari_version_specifiers(specifier_intersection, version_speci intersection = combine_minari_version_specifiers(version_specifiers) assert specifier_intersection == intersection - - -# in the future, if the logic of save metadata of combined dataset changes, this should be changed as well -def test_combine_dataset_with_different_metadata(): - n_data = 2 - dataset_list = [] - for i in range(n_data): - dataset_id = f"cartpole-test-{i}-v0" - env = gym.make("CartPole-v1", max_episode_steps=500) - env = DataCollectorV0(env) - env.reset(seed=42) - for episode in range(5): - terminated = False - truncated = False - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - _, _, terminated, truncated, _ = env.step(action) - env.reset() - - # Create Minari dataset and store locally - permalink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" - dataset = minari.create_dataset_from_collector_env( - dataset_id=dataset_id, - collector_env=env, - algorithm_name="random_policy" + str(i), - code_permalink=permalink + str(i), - author="WillDudley" + str(i), - author_email="wdudley@farama.org" + str(i), - ) - assert isinstance(dataset, MinariDataset) - env.close() - dataset_list.append(dataset) - - combined_dataset = combine_datasets( - dataset_list, new_dataset_id="cartpole-combined-test-v0" - ) - permalink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" - with h5py.File(combined_dataset.spec.data_path) as dt_file: - assert dt_file.attrs["algorithm_name"] == "random_policy" + str(n_data - 1) - _final_code_link = permalink + str(n_data - 1) - assert dt_file.attrs["code_permalink"] == _final_code_link - assert dt_file.attrs["author"] == "WillDudley" + str(n_data - 1) - assert dt_file.attrs["author_email"] == "wdudley@farama.org" + str(n_data - 1) - - for i in range(n_data): - minari.delete_dataset(f"cartpole-test-{i}-v0") - minari.delete_dataset("cartpole-combined-test-v0") - return diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 393789e4..a35a5a82 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -3,7 +3,6 @@ from typing 
import Dict import gymnasium as gym -import h5py import numpy as np import pytest from gymnasium import spaces @@ -70,14 +69,12 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): author_email="wdudley@farama.org", ) - # test metadata - - with h5py.File(dataset.spec.data_path, "r") as data_file: - assert data_file.attrs["algorithm_name"] == "random_policy" - codelink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" - assert data_file.attrs["code_permalink"] == codelink - assert data_file.attrs["author"] == "WillDudley" - assert data_file.attrs["author_email"] == "wdudley@farama.org" + metadata = dataset.storage.metadata + assert metadata["algorithm_name"] == "random_policy" + codelink = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" + assert metadata["code_permalink"] == codelink + assert metadata["author"] == "WillDudley" + assert metadata["author_email"] == "wdudley@farama.org" assert isinstance(dataset, MinariDataset) assert dataset.total_episodes == num_episodes @@ -287,13 +284,12 @@ def test_generate_dataset_with_space_subset_external_buffer(): observation_space=observation_space_subset, ) - # test metadata - with h5py.File(dataset.spec.data_path, "r") as data_file: - assert data_file.attrs["algorithm_name"] == "random_policy" - code_link = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" - assert data_file.attrs["code_permalink"] == code_link - assert data_file.attrs["author"] == "WillDudley" - assert data_file.attrs["author_email"] == "wdudley@farama.org" + metadata = dataset.storage.metadata + assert metadata["algorithm_name"] == "random_policy" + code_link = "https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py" + assert metadata["code_permalink"] == code_link + assert metadata["author"] == "WillDudley" + assert metadata["author_email"] == "wdudley@farama.org" assert isinstance(dataset, MinariDataset) assert dataset.total_episodes == num_episodes From c5c88e8a03921ac8012a632c741916560670acc2 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 19 Oct 2023 17:02:35 -0400 Subject: [PATCH 18/19] fix metadata on GCP --- minari/storage/hosting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py index d26e9c84..2bf2aea2 100644 --- a/minari/storage/hosting.py +++ b/minari/storage/hosting.py @@ -5,6 +5,7 @@ import os import warnings from typing import Dict, List +import h5py from google.cloud import storage # pyright: ignore [reportGeneralTypeIssues] from gymnasium import logger @@ -43,7 +44,8 @@ def _upload_local_directory_to_gcs(local_path, bucket, gcs_path): blob = bucket.blob(remote_path) # add metadata to main data file of dataset if blob.name.endswith("main_data.hdf5"): - blob.metadata = metadata + with h5py.File(local_file, "r") as file: # TODO: remove h5py when migrating to JSON metadata + blob.metadata = file.attrs blob.upload_from_filename(local_file) file_path = get_dataset_path(dataset_id) From 707f138daf9641c3d697a7f0c88b1485e4d920f2 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Oct 2023 18:02:54 -0400 Subject: [PATCH 19/19] fix pre-commit --- minari/storage/hosting.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py index 2bf2aea2..375c66f7 100644 --- a/minari/storage/hosting.py +++ b/minari/storage/hosting.py @@ -5,15 +5,14 @@ 
import os import warnings from typing import Dict, List -import h5py +import h5py from google.cloud import storage # pyright: ignore [reportGeneralTypeIssues] from gymnasium import logger from packaging.specifiers import SpecifierSet from tqdm.auto import tqdm # pyright: ignore [reportMissingModuleSource] from minari.dataset.minari_dataset import parse_dataset_id -from minari.dataset.minari_storage import MinariStorage from minari.storage.datasets_root_dir import get_dataset_path from minari.storage.local import load_dataset @@ -44,7 +43,9 @@ def _upload_local_directory_to_gcs(local_path, bucket, gcs_path): blob = bucket.blob(remote_path) # add metadata to main data file of dataset if blob.name.endswith("main_data.hdf5"): - with h5py.File(local_file, "r") as file: # TODO: remove h5py when migrating to JSON metadata + with h5py.File( + local_file, "r" + ) as file: # TODO: remove h5py when migrating to JSON metadata blob.metadata = file.attrs blob.upload_from_filename(local_file) @@ -58,8 +59,6 @@ def _upload_local_directory_to_gcs(local_path, bucket, gcs_path): dataset = load_dataset(dataset_id) - metadata = MinariStorage(dataset.spec.data_path).metadata - # See https://github.com/googleapis/python-storage/issues/27 for discussion on progress bars _upload_local_directory_to_gcs(str(file_path), bucket, dataset_id)
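
Usage sketch (editor's addition, not part of the patch series above): the tests added in this series exercise two API changes, reading dataset metadata as a plain dict from dataset.storage.metadata instead of raw h5py attributes, and passing a NumPy array of undiscounted returns to get_normalized_score. The snippet below is illustrative only; it assumes a locally available dataset named "cartpole-test-v0" that was created with ref_min_score/ref_max_score metadata, as in test_ref_score.

import numpy as np

import minari
from minari import get_normalized_score

# Load a locally available dataset, e.g. the one built in test_ref_score.
dataset = minari.load_dataset("cartpole-test-v0")

# Metadata (author, algorithm_name, ref_min_score, ref_max_score, ...) is now
# exposed as a plain dict on the storage object rather than h5py attributes.
metadata = dataset.storage.metadata
print(metadata["algorithm_name"], metadata.get("ref_min_score"), metadata.get("ref_max_score"))

# get_normalized_score now accepts an array of undiscounted returns and applies
# (returns - ref_min_score) / (ref_max_score - ref_min_score) element-wise.
returns = np.array([0.0, 49.5, 100.0])
print(get_normalized_score(dataset, returns))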