diff --git a/docs/tutorials/Aug_DiskCachDataset.ipynb b/docs/tutorials/Aug_DiskCachDataset.ipynb
new file mode 100644
index 0000000..5a203d9
--- /dev/null
+++ b/docs/tutorials/Aug_DiskCachDataset.ipynb
@@ -0,0 +1,179 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Using Aug_DiskCachedDataset for efficient caching of augmented copies\n",
+    "- `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` that is useful when applying deterministic augmentations to data samples.\n",
+    "\n",
+    "- This is the case when the parameter space of the augmentation is discrete, for instance when applying `pitchshift` to audio data, where the shift parameter (semitone) can only take N values.\n",
+    "\n",
+    "- Using `DiskCachedDataset` and setting `num_copies` to N is likely to cause two issues:\n",
+    "\n",
+    "  - copies might not be unique, as the copy index is not linked to the augmentation parameter,\n",
+    "  - and there is no guarantee that the copies cover the desired augmentation space.\n",
+    "\n",
+    "- `Aug_DiskCachedDataset` resolves this limitation by linking the copy index to the augmentation parameter. The following considerations need to be taken into account:\n",
+    "\n",
+    "  - The user needs to pass the `all_transforms` dict as input, with the transforms separated under the keys `pre_aug`, `augmentations` and `post_aug` (the transforms applied before augmentation, the augmentation transforms themselves, and the transforms applied after augmentation); see the sketch right after this list.\n",
+    "\n",
+    "  - The augmentation class receives `aug_index` (aug_index = copy index) as an initialization parameter, and `caching=True` needs to be set (please see `tonic.audio_augmentations`).\n",
+    "\n",
+    "- The following is a simple example that shows how `Aug_DiskCachedDataset` works."
+   ]
+  },
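+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below is a minimal sketch of the expected structure of `all_transforms`; the empty lists are placeholders, and concrete transforms are filled in later in this tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the three keys are fixed; each value is a list of transforms (callables)\n",
+    "all_transforms_sketch = {\n",
+    "    \"pre_aug\": [],        # applied before the augmentation\n",
+    "    \"augmentations\": [],  # deterministic augmentation(s) created with caching=True\n",
+    "    \"post_aug\": [],       # applied after the augmentation\n",
+    "}"
+   ]
+  },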
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### A simple dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# %%writefile mini_dataset.py\n",
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "from torch.utils.data import Dataset\n",
+    "import numpy as np\n",
+    "\n",
+    "class mini_dataset(Dataset):\n",
+    "    def __init__(self) -> None:\n",
+    "        super().__init__()\n",
+    "        np.random.seed(0)\n",
+    "        self.data = np.random.rand(10, 16000)\n",
+    "        self.transform = None\n",
+    "        self.target_transform = None\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        sample = self.data[index]\n",
+    "        label = 1\n",
+    "        if sample.ndim == 1:\n",
+    "            sample = sample[None, ...]\n",
+    "        if self.transform is not None:\n",
+    "            sample = self.transform(sample)\n",
+    "        if self.target_transform is not None:\n",
+    "            label = self.target_transform(label)\n",
+    "\n",
+    "        return sample, label"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Initializing `Aug_DiskCachedDataset` with transforms"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache\n",
+    "from tonic.audio_transforms import AmplitudeScale, FixLength\n",
+    "from tonic.audio_augmentations import RandomPitchShift\n",
+    "\n",
+    "all_transforms = {}\n",
+    "all_transforms[\"pre_aug\"] = [AmplitudeScale(max_amplitude=0.150)]\n",
+    "all_transforms[\"augmentations\"] = [RandomPitchShift(samplerate=16000, caching=True)]\n",
+    "all_transforms[\"post_aug\"] = [FixLength(16000)]\n",
+    "\n",
+    "# the number of copies is set to the number of augmentation parameters (factors)\n",
+    "n = len(RandomPitchShift(samplerate=16000, caching=True).factors)\n",
+    "Aug_cach = Aug_DiskCachedDataset(dataset=mini_dataset(), cache_path='cache/', all_transforms=all_transforms, num_copies=n)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Generating all copies of a data sample\n",
+    "- 10 augmented versions of the data sample with index = 0 are generated (one per augmentation factor)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sample_index = 0\n",
+    "Aug_cach.generate_all(sample_index)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### To verify\n",
+    "- loading the saved copies\n",
+    "- and comparing them with samples generated outside of `Aug_DiskCachedDataset`, using the same transforms and the matching augmentation parameter,\n",
+    "- confirms they are equal\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n",
+      "True\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchvision.transforms import Compose\n",
+    "\n",
+    "for i in range(n):\n",
+    "    transform = Compose([AmplitudeScale(max_amplitude=0.150), RandomPitchShift(samplerate=16000, caching=True, aug_index=i), FixLength(16000)])\n",
+    "    ds = mini_dataset()\n",
+    "    ds.transform = transform\n",
+    "    sample = ds[sample_index][0]\n",
+    "    data, targets = load_from_disk_cache('cache/' + '0_' + str(i) + '.hdf5')\n",
+    "    print((sample == data).all())\n"
+   ]
+  },
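+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Reading from the cache\n",
+    "Once the copies exist, indexing works like any other `DiskCachedDataset`: `__getitem__` loads one of the `num_copies` cached copies chosen at random (and generates it first if it is missing). A minimal usage sketch:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# each access returns a randomly chosen augmented copy of the sample\n",
+    "sample, label = Aug_cach[sample_index]\n",
+    "print(sample.shape, label)"
+   ]
+  }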
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "py_310",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/requirements.txt b/requirements.txt
index 6dcf922..db9ece2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ typing_extensions
 librosa
 pbr
 expelliarmus
+typing-extensions
diff --git a/test/test_aug_caching.py b/test/test_aug_caching.py
new file mode 100644
index 0000000..ab4dfeb
--- /dev/null
+++ b/test/test_aug_caching.py
@@ -0,0 +1,66 @@
+import os
+
+import numpy as np
+
+from tonic.audio_augmentations import RandomPitchShift
+from tonic.audio_transforms import AmplitudeScale, FixLength
+from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache
+
+
+class mini_dataset:
+    def __init__(self) -> None:
+        np.random.seed(0)
+        self.data = np.random.rand(10, 16000)
+        self.transform = None
+        self.target_transform = None
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        label = 1
+        if sample.ndim == 1:
+            sample = sample[None, ...]
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return sample, label
+
+
+def test_aug_disk_caching():
+    from torchvision.transforms import Compose
+
+    all_transforms = {}
+    all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
+    all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]
+    all_transforms["post_aug"] = [FixLength(16000)]
+    # the number of copies is set to the number of augmentation parameters (factors)
+    n = len(RandomPitchShift(samplerate=16000, caching=True).factors)
+    Aug_cach = Aug_DiskCachedDataset(
+        dataset=mini_dataset(),
+        cache_path="cache/",
+        all_transforms=all_transforms,
+        num_copies=n,
+    )
+
+    if not os.path.isdir("cache/"):
+        os.mkdir("cache/")
+
+    sample_index = 0
+    Aug_cach.generate_all(sample_index)
+
+    for i in range(n):
+        transform = Compose(
+            [
+                AmplitudeScale(max_amplitude=0.150),
+                RandomPitchShift(samplerate=16000, caching=True, aug_index=i),
+                FixLength(16000),
+            ]
+        )
+        ds = mini_dataset()
+        ds.transform = transform
+        augmented_sample = ds[sample_index][0]
+        loaded_sample, targets = load_from_disk_cache(
+            "cache/" + "0_" + str(i) + ".hdf5"
+        )
+        assert (augmented_sample == loaded_sample).all()
diff --git a/tonic/__init__.py b/tonic/__init__.py
index 26a0fa7..222eb35 100644
--- a/tonic/__init__.py
+++ b/tonic/__init__.py
@@ -1,7 +1,12 @@
 from pbr.version import VersionInfo
 
 from . import collation, datasets, io, slicers, transforms, utils
-from .cached_dataset import CachedDataset, DiskCachedDataset, MemoryCachedDataset
+from .cached_dataset import (
+    Aug_DiskCachedDataset,
+    CachedDataset,
+    DiskCachedDataset,
+    MemoryCachedDataset,
+)
 from .dataset import Dataset
 from .sliced_dataset import SlicedDataset
diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index 69c2ef7..6136929 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -1,9 +1,17 @@
 import logging
 import os
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Callable, Iterable, Optional, Tuple, TypedDict, Union
+else:
+    from typing import Callable, Iterable, Optional, Tuple, Union
+    from typing_extensions import TypedDict
+
+import random
 import shutil
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Tuple, Union
 from warnings import warn
 
 import h5py
@@ -227,6 +235,83 @@ def load_from_disk_cache(file_path: Union[str, Path]) -> Tuple:
     return data_list, target_list
 
 
+@dataclass
+class Aug_DiskCachedDataset(DiskCachedDataset):
+    """Aug_DiskCachedDataset is a child class of DiskCachedDataset, with further customizations
+    to handle augmented copies of a sample. The goal of this customization is to map the index
+    of each cached file (the copy index) to an augmentation parameter. This is useful for the
+    category of augmentations whose parameter range is discrete and non-probabilistic, for
+    instance an audio sample augmented with noise whose SNR can take only N=5 values. Passing
+    the copy index to the augmentation class as an init argument ensures that each copy is a
+    distinct augmented sample with a trackable parameter.
+
+    The `generate_all` method generates all augmented versions of a sample.
+    The `generate_copy` method generates a missing variant (augmented version).
+
+    All transforms applied to the dataset are therefore categorized by the keys "pre_aug",
+    "augmentations" and "post_aug".
+
+    Args:
+        all_transforms: a dictionary passed to this class that contains all transforms,
+            separated under the keys "pre_aug", "augmentations" and "post_aug".
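+
+    Example (a minimal sketch; `my_dataset` is a hypothetical dataset exposing a
+    `transform` attribute, and `RandomPitchShift` comes from `tonic.audio_augmentations`):
+
+        >>> all_transforms = {
+        ...     "pre_aug": [],
+        ...     "augmentations": [RandomPitchShift(samplerate=16000, caching=True)],
+        ...     "post_aug": [],
+        ... }
+        >>> ds = Aug_DiskCachedDataset(
+        ...     dataset=my_dataset,
+        ...     cache_path="cache/",
+        ...     all_transforms=all_transforms,
+        ...     num_copies=5,
+        ... )
+        >>> ds.generate_all(0)  # caches copies as cache/0_0.hdf5 ... cache/0_4.hdf5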
+    """
+
+    all_transforms: Optional[TypedDict] = None
+
+    def __post_init__(self):
+        self.pre_aug = self.all_transforms["pre_aug"]
+        self.aug = self.all_transforms["augmentations"]
+        self.post_aug = self.all_transforms["post_aug"]
+
+    def generate_all(self, item):
+        # try to load each copy; generate any copy that is missing from the cache
+        for copy in range(0, self.num_copies):
+            file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+            try:
+                data, targets = load_from_disk_cache(file_path)
+            except (FileNotFoundError, OSError) as _:
+                self.generate_copy(item, copy)
+
+    def generate_copy(self, item, copy):
+        from torchvision.transforms import Compose
+
+        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+        # the copy index is passed to the augmentation (callable)
+        self.aug[0].aug_index = copy
+        augmentation = self.aug
+        self.dataset.transform = Compose(self.pre_aug + augmentation + self.post_aug)
+        data, targets = self.dataset[item]
+        save_to_disk_cache(data, targets, file_path=file_path, compress=self.compress)
+
+    def __getitem__(self, item) -> Tuple[object, object]:
+        if self.dataset is None and item >= self.n_samples:
+            raise IndexError(f"This dataset only has {self.n_samples} items.")
+
+        copy = random.randint(0, self.num_copies - 1)
+        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+        try:
+            data, targets = load_from_disk_cache(file_path)
+
+        except (FileNotFoundError, OSError) as _:
+            logging.info(
+                f"Data {item}: {file_path} not in cache, generating it now",
+                stacklevel=2,
+            )
+            self.generate_copy(item, copy)
+
+            # format might change during save to hdf5, i.e. tensors -> np arrays
+            # We load the sample here again to keep the output format consistent.
+            data, targets = load_from_disk_cache(file_path)
+
+        if self.transform is not None:
+            data = self.transform(data)
+        if self.target_transform is not None:
+            targets = self.target_transform(targets)
+        if self.transforms is not None:
+            data, targets = self.transforms(data, targets)
+        return data, targets
+
+
 class CachedDataset(DiskCachedDataset):
     """Deprecated class that points to DiskCachedDataset for now but will be removed in a future release.