From 61c7f8579ac4edcc439f1e9327204875500e3b15 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Fri, 24 Nov 2023 15:36:02 +0100
Subject: [PATCH 01/13] Aug_DiskCachedDataset added

---
 tonic/cached_dataset.py | 83 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 82 insertions(+), 1 deletion(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index 69c2ef79..e6ada314 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -1,13 +1,15 @@
 import logging
 import os
+import random
 import shutil
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, Optional, Tuple, TypedDict, Union
 from warnings import warn
 
 import h5py
 import numpy as np
+from torchvision.transforms import Compose
 
 
 @dataclass
@@ -227,6 +229,85 @@ def load_from_disk_cache(file_path: Union[str, Path]) -> Tuple:
     return data_list, target_list
 
 
+@dataclass
+class Aug_DiskCachedDataset(DiskCachedDataset):
+    """Aug_DiskCachedDataset is a child class of DiskCachedDataset with further customizations to
+    handle augmented copies of a sample. The goal of this customization is to map the indices of
+    cached files (copies) to augmentation parameters. This is useful in a category of augmentations
+    where the range of the parameter is discrete and non-probabilistic, for instance an ausio
+    sample is being augmented with noise and the SNR can take only N=5 values. Passing copy_index to
+    the augmentation class as an init argument ensures that each copy will be a distinct augmented
+    sample with a trackable parameter.
+
+    The 'generate_all' method generates all augmented versions of a sample.
+    The 'generate_copy' method generates the missing variant (augmented version).
+
+
+
+    Therefore all transforms applied to the dataset are categorized by the keys: "detrministic", "augmentations"
+    and "to_spike_generation".
+
+    Args:
+        'all_transforms' is a dictionary passed to this class containing information about augmentations.
+    """
+
+    all_transforms: Optional[TypedDict] = None
+
+    def __post_init__(self):
+        super().__init__()
+        self.deterministic_transform = self.all_transforms["detrministic"]
+        self.cached_aug = self.all_transforms["augmentations"]
+        self.to_spike_transform = self.all_transforms["to_spike_generation"]
+
+    def generate_all(self, item):
+        for copy in range(0, self.num_copies):
+            file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+            try:
+                data, targets = load_from_disk_cache(file_path)
+            except (FileNotFoundError, OSError) as _:
+                self.generate_copy(item, copy)
+
+    def generate_copy(self, item, copy):
+        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+        # copy index is passed to augmentation (callable)
+        self.cached_aug.aug_index = copy
+        augmentation = [self.cached_aug]
+        self.dataset.transform = Compose(
+            self.deterministic_transform + augmentation + self.to_spike_transform
+        )
+        data, targets = self.dataset[item]
+        save_to_disk_cache(data, targets, file_path=file_path, compress=self.compress)
+
+    def __getitem__(self, item) -> Tuple[object, object]:
+        if self.dataset is None and item >= self.n_samples:
+            raise IndexError(f"This dataset only has {self.n_samples} items.")
+
+        copy = random.randint(0, self.num_copies - 1)
+        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
+        try:
+            data, targets = load_from_disk_cache(file_path)
+
+        except (FileNotFoundError, OSError) as _:
+            logging.info(
+                f"Data {item}: {file_path} not in cache, generating it now",
+                stacklevel=2,
+            )
+            # self.generate_all(item)
+            self.generate_copy(item, copy)
+
+            # format might change during save to hdf5, i.e. tensors -> np arrays
+            # We load the sample here again to keep the output format consistent.
+            data, targets = load_from_disk_cache(file_path)
+
+        if self.transform is not None:
+            data = self.transform(data)
+        if self.target_transform is not None:
+            targets = self.target_transform(targets)
+        if self.transforms is not None:
+            data, targets = self.transforms(data, targets)
+        return data, targets
+
+
 class CachedDataset(DiskCachedDataset):
     """Deprecated class that points to DiskCachedDataset for now but will be removed in a future release.
 

From 9d00ca83a70748df523d02610580f50ba4f305a4 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Mon, 4 Dec 2023 11:50:45 +0100
Subject: [PATCH 02/13] dict keys updated to more generic

---
 tonic/cached_dataset.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index e6ada314..a032fef0 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -244,8 +244,8 @@ class Aug_DiskCachedDataset(DiskCachedDataset):
 
 
-    Therefore all transforms applied to the dataset are categorized by the keys: "detrministic", "augmentations"
-    and "to_spike_generation".
+    Therefore all transforms applied to the dataset are categorized by the keys: "pre_aug", "augmentations"
+    and "post_aug".
 
     Args:
         'all_transforms' is a dictionary passed to this class containing information about augmentations.
@@ -255,9 +255,9 @@ class Aug_DiskCachedDataset(DiskCachedDataset):
 
     def __post_init__(self):
         super().__init__()
-        self.deterministic_transform = self.all_transforms["detrministic"]
-        self.cached_aug = self.all_transforms["augmentations"]
-        self.to_spike_transform = self.all_transforms["to_spike_generation"]
+        self.pre_aug = self.all_transforms["pre_aug"]
+        self.aug = self.all_transforms["augmentations"]
+        self.post_aug = self.all_transforms["post_aug"]
 
     def generate_all(self, item):
         for copy in range(0, self.num_copies):
@@ -270,11 +270,9 @@ def generate_all(self, item):
     def generate_copy(self, item, copy):
         file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
         # copy index is passed to augmentation (callable)
-        self.cached_aug.aug_index = copy
-        augmentation = [self.cached_aug]
-        self.dataset.transform = Compose(
-            self.deterministic_transform + augmentation + self.to_spike_transform
-        )
+        self.aug.aug_index = copy
+        augmentation = [self.aug]
+        self.dataset.transform = Compose(self.pre_aug + augmentation + self.post_aug)
         data, targets = self.dataset[item]
         save_to_disk_cache(data, targets, file_path=file_path, compress=self.compress)

From fb64e657fb5ed150b88981ec2fffa231a6cd1ee8 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Tue, 5 Dec 2023 15:30:34 +0100
Subject: [PATCH 03/13] small fixes and cleanup

---
 tonic/cached_dataset.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index a032fef0..e6c52cfa 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -234,7 +234,7 @@ class Aug_DiskCachedDataset(DiskCachedDataset):
 """Aug_DiskCachedDataset is a child class of DiskCachedDataset with further customizations to
     handle augmented copies of a sample. The goal of this customization is to map the indices of
     cached files (copies) to augmentation parameters. This is useful in a category of augmentations
-    where the range of the parameter is discrete and non-probabilistic, for instance an ausio
+    where the range of the parameter is discrete and non-probabilistic, for instance an audio
     sample is being augmented with noise and the SNR can take only N=5 values. Passing copy_index to
     the augmentation class as an init argument ensures that each copy will be a distinct augmented
     sample with a trackable parameter.
@@ -248,7 +248,7 @@ class Aug_DiskCachedDataset(DiskCachedDataset):
     and "post_aug".
 
     Args:
-        'all_transforms' is a dictionary passed to this class containing information about augmentations.
+        'all_transforms' is a dictionary passed to this class containing information about all transforms.
     """
 
     all_transforms: Optional[TypedDict] = None
@@ -290,7 +290,6 @@ def __getitem__(self, item) -> Tuple[object, object]:
             f"Data {item}: {file_path} not in cache, generating it now",
             stacklevel=2,
         )
-        # self.generate_all(item)
         self.generate_copy(item, copy)
 
         # format might change during save to hdf5, i.e. tensors -> np arrays
         # We load the sample here again to keep the output format consistent.
         data, targets = load_from_disk_cache(file_path)

From ccafa5a68a747264e06706e8bc45c00fe7cecd06 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Tue, 5 Dec 2023 16:03:13 +0100
Subject: [PATCH 04/13] adding typing-extensions to requirements

---
 test/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/requirements.txt b/test/requirements.txt
index 33a0ea37..09105c3a 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -6,3 +6,4 @@ hdf5plugin
 imageio
 torchdata
 aedat
+typing-extensions
\ No newline at end of file

From d2c995e0acf1d2a9a539022f17c5519e0a1d89cb Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Tue, 5 Dec 2023 16:08:25 +0100
Subject: [PATCH 05/13] importing TypedDict from typing_extensions for Python 3.7

---
 tonic/cached_dataset.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index e6c52cfa..5ca24951 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -1,10 +1,17 @@
 import logging
 import os
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Callable, Iterable, Optional, Tuple, TypedDict, Union
+else:
+    from typing import Callable, Iterable, Optional, Tuple, Union
+    from typing_extensions import TypedDict
+
 import random
 import shutil
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Tuple, TypedDict, Union
 from warnings import warn
 
 import h5py

From 170f31bd418dcb61a894d2733b09d452714c2da9 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Wed, 20 Dec 2023 18:34:32 +0100
Subject: [PATCH 06/13] bug fixed in Aug_DiskCachedDataset

---
 tonic/cached_dataset.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index 5ca24951..55e09e42 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -261,7 +261,6 @@ class Aug_DiskCachedDataset(DiskCachedDataset):
     all_transforms: Optional[TypedDict] = None
 
     def __post_init__(self):
-        super().__init__()
         self.pre_aug = self.all_transforms["pre_aug"]
         self.aug = self.all_transforms["augmentations"]
         self.post_aug = self.all_transforms["post_aug"]
@@ -277,8 +276,8 @@ def generate_all(self, item):
     def generate_copy(self, item, copy):
         file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
         # copy index is passed to augmentation (callable)
-        self.aug.aug_index = copy
-        augmentation = [self.aug]
+        self.aug[0].aug_index = copy
+        augmentation = self.aug
         self.dataset.transform = Compose(self.pre_aug + augmentation + self.post_aug)
         data, targets = self.dataset[item]
         save_to_disk_cache(data, targets, file_path=file_path, compress=self.compress)

From 5c0f9a77af5d24a40eb3f08b9286be15c89c69f1 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Wed, 20 Dec 2023 19:09:13 +0100
Subject: [PATCH 07/13] including Aug_DiskCachedDataset in init imports

---
 tonic/__init__.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tonic/__init__.py b/tonic/__init__.py
index 26a0fa7b..222eb355 100644
--- a/tonic/__init__.py
+++ b/tonic/__init__.py
@@ -1,7 +1,12 @@
 from pbr.version import VersionInfo
 
 from . import collation, datasets, io, slicers, transforms, utils
-from .cached_dataset import CachedDataset, DiskCachedDataset, MemoryCachedDataset
+from .cached_dataset import (
+    Aug_DiskCachedDataset,
+    CachedDataset,
+    DiskCachedDataset,
+    MemoryCachedDataset,
+)
 from .dataset import Dataset
 from .sliced_dataset import SlicedDataset

From b3c5d1b9262530122014ebac7c840ba9b9d6b6e8 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Wed, 20 Dec 2023 19:10:11 +0100
Subject: [PATCH 08/13] test added for aug_diskcached

---
 test/test_aug_caching.py | 67 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 test/test_aug_caching.py

diff --git a/test/test_aug_caching.py b/test/test_aug_caching.py
new file mode 100644
index 00000000..8f120dbf
--- /dev/null
+++ b/test/test_aug_caching.py
@@ -0,0 +1,67 @@
+import os
+
+import numpy as np
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from tonic.audio_augmentations import RandomPitchShift
+from tonic.audio_transforms import AmplitudeScale, FixLength
+from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache
+
+
+class mini_dataset(Dataset):
+    def __init__(self) -> None:
+        super().__init__()
+        np.random.seed(0)
+        self.data = np.random.rand(10, 16000)
+        self.transform = None
+        self.target_transform = None
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        label = 1
+        if sample.ndim == 1:
+            sample = sample[None, ...]
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return sample, label
+
+
+def test_aug_disk_caching():
+    all_transforms = {}
+    all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
+    all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]
+    all_transforms["post_aug"] = [FixLength(16000)]
+    # number of copies is set to number of augmentation params (factors)
+    n = len(RandomPitchShift(samplerate=16000, caching=True).factors)
+    Aug_cach = Aug_DiskCachedDataset(
+        dataset=mini_dataset(),
+        cache_path="cache/",
+        all_transforms=all_transforms,
+        num_copies=n,
+    )
+
+    if not os.path.isdir("cache/"):
+        os.mkdir("cache/")
+
+    sample_index = 0
+    Aug_cach.generate_all(sample_index)
+
+    for i in range(n):
+        transform = Compose(
+            [
+                AmplitudeScale(max_amplitude=0.150),
+                RandomPitchShift(samplerate=16000, caching=True, aug_index=i),
+                FixLength(16000),
+            ]
+        )
+        ds = mini_dataset()
+        ds.transform = transform
+        augmented_sample = ds[sample_index][0]
+        loaded_sample, targets = load_from_disk_cache(
+            "cache/" + "0_" + str(i) + ".hdf5"
+        )
+        assert (augmented_sample == loaded_sample).all()

From c9d26b034e0a6839be7f4825ff5858b7ee2cc233 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Wed, 20 Dec 2023 19:11:12 +0100
Subject: [PATCH 09/13] notebook added to elaborate the function of
 Aug_DiskCachedDataset

---
 docs/tutorials/Aug_DiskCachDataset.ipynb | 179 +++++++++++++++++++++++
 1 file changed, 179 insertions(+)
 create mode 100644 docs/tutorials/Aug_DiskCachDataset.ipynb

diff --git a/docs/tutorials/Aug_DiskCachDataset.ipynb b/docs/tutorials/Aug_DiskCachDataset.ipynb
new file mode 100644
index 00000000..5a203d93
--- /dev/null
+++ b/docs/tutorials/Aug_DiskCachDataset.ipynb
@@ -0,0 +1,179 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Aug_DiskCachedDataset for efficient caching of augmented copies\n",
    "- `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` that is useful when applying deterministic augmentations to data samples. \n",
    "\n",
    "- This is the case when the parameter space of the augmentation is discrete, for instance applying `pitchshift` to audio data, in which the shift parameter (semitones) can take only N values.\n",
    "\n",
    "- Using `DiskCachedDataset` and setting `num_copies` to N is likely to cause two issues:\n",
    "\n",
    "    - Copies might not be unique, as copy_index is not linked to the augmentation parameter \n",
    "    - There is no guarantee that the copies cover the desired augmentation space\n",
    "    \n",
    "\n",
    "\n",
    "- `Aug_DiskCachedDataset` resolves this limitation by linking the copy index to the augmentation parameter. The following considerations need to be taken into account:\n",
    "\n",
    "    - The user needs to pass the `all_transforms` dict as input, with separate entries `pre_aug`, `augmentations`, and `post_aug` (specifying the transforms applied before augmentation, the augmentation transforms themselves, and the transforms applied after augmentation). \n",
    "    \n",
    "    - The augmentation class receives `aug_index` (aug_index = copy index) as an initialization parameter, and `caching=True` needs to be set (please see `tonic.audio_augmentations`)\n",
    "\n",
    "- The following is a simple example to show the function of `Aug_DiskCachedDataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A simple dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%writefile mini_dataset.py\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from torch.utils.data import Dataset\n",
    "import numpy as np\n",
    "\n",
    "class mini_dataset(Dataset):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        np.random.seed(0)\n",
    "        self.data = np.random.rand(10, 16000)\n",
    "        self.transform = None\n",
    "        self.target_transform = None\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sample = self.data[index]\n",
    "        label = 1\n",
    "        if sample.ndim==1:\n",
    "            sample = sample[None,...]\n",
    "        if self.transform is not None:\n",
    "            sample = self.transform(sample)\n",
    "        if self.target_transform is not None:\n",
    "            label = self.target_transform(label) \n",
    "\n",
    "        return sample, label "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initializing `Aug_DiskCachedDataset` with transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache\n",
    "from tonic.audio_transforms import AmplitudeScale, FixLength\n",
    "from tonic.audio_augmentations import RandomPitchShift\n",
    "\n",
    "all_transforms = {}\n",
    "all_transforms[\"pre_aug\"] = [AmplitudeScale(max_amplitude = 0.150)]\n",
    "all_transforms[\"augmentations\"] = [RandomPitchShift(samplerate=16000, caching=True)]\n",
    "all_transforms[\"post_aug\"] = [FixLength(16000)]\n",
    "\n",
    "# number of copies is set to number of augmentation params (factors)\n",
    "n = len(RandomPitchShift(samplerate=16000, caching=True).factors)\n",
    "Aug_cach = Aug_DiskCachedDataset(dataset=mini_dataset(), cache_path='cache/', all_transforms = all_transforms, num_copies=n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating all copies of a data sample\n",
    " - 10 augmented versions of the data sample with index = 0 are generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_index = 0\n",
    "Aug_cach.generate_all(sample_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### To verify\n",
    " - loading the saved copies\n",
    " - comparing them with samples generated outside of `Aug_DiskCachedDataset`, using the same transforms and the matching augmentation parameter\n",
    " - they are equal\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "from torchvision.transforms import Compose\n",
    "\n",
    "for i in range(n):\n",
    "    transform = Compose([AmplitudeScale(max_amplitude = 0.150),RandomPitchShift(samplerate=16000, caching=True, aug_index=i), FixLength(16000)])\n",
    "    ds = mini_dataset()\n",
    "    ds.transform = transform\n",
    "    sample = ds[sample_index][0]\n",
    "    data, targets = load_from_disk_cache('cache/' + '0_' + str(i) + '.hdf5' )\n",
    "    print((sample==data).all())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py_310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

From 2036b8fc29aa419682788e67dba29bceccc16b24 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Thu, 23 May 2024 12:25:02 +0200
Subject: [PATCH 10/13] moving torchvision import to inside function, where it
 is used

---
 tonic/cached_dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tonic/cached_dataset.py b/tonic/cached_dataset.py
index 55e09e42..6136929a 100644
--- a/tonic/cached_dataset.py
+++ b/tonic/cached_dataset.py
@@ -16,7 +16,6 @@
 
 import h5py
 import numpy as np
-from torchvision.transforms import Compose
 
 
 @dataclass
@@ -274,6 +273,8 @@ def generate_all(self, item):
             self.generate_copy(item, copy)
 
     def generate_copy(self, item, copy):
+        from torchvision.transforms import Compose
+
         file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
         # copy index is passed to augmentation (callable)
         self.aug[0].aug_index = copy

From 0ef6582f10733f90ac8f2273108b25abef54266e Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Thu, 23 May 2024 12:26:33 +0200
Subject: [PATCH 11/13] moving typing-extensions to the root requirements

---
 requirements.txt      | 1 +
 test/requirements.txt | 1 -
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 6dcf922c..db9ece24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ typing_extensions
 librosa
 pbr
 expelliarmus
+typing-extensions

diff --git a/test/requirements.txt b/test/requirements.txt
index c24f4d2e..d799e198 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -4,4 +4,3 @@ matplotlib
 hdf5plugin
 imageio
 aedat
-typing-extensions
\ No newline at end of file

From dd2baecce22a1be62139511d6a429ed6f9cc327b Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Thu, 30 May 2024 11:42:45 +0200
Subject: [PATCH 12/13] moving torchvision import to inside function

---
 test/test_aug_caching.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_aug_caching.py b/test/test_aug_caching.py
index 8f120dbf..efb8165e 100644
--- a/test/test_aug_caching.py
+++ b/test/test_aug_caching.py
@@ -2,7 +2,6 @@
 
 import numpy as np
 from torch.utils.data import Dataset
-from torchvision.transforms import Compose
 
 from tonic.audio_augmentations import RandomPitchShift
 from tonic.audio_transforms import AmplitudeScale, FixLength
@@ -31,6 +30,8 @@ def __getitem__(self, index):
 
 
 def test_aug_disk_caching():
+    from torchvision.transforms import Compose
+
     all_transforms = {}
     all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
     all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]

From c799c41035c107d4c0ae303419363e75106414f4 Mon Sep 17 00:00:00 2001
From: Mina Khoei
Date: Thu, 30 May 2024 11:51:01 +0200
Subject: [PATCH 13/13] Dataset class import was removed; the mini dataset
 does not inherit from it

---
 test/test_aug_caching.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/test_aug_caching.py b/test/test_aug_caching.py
index efb8165e..ab4dfebb 100644
--- a/test/test_aug_caching.py
+++ b/test/test_aug_caching.py
@@ -1,16 +1,14 @@
 import os
 
 import numpy as np
-from torch.utils.data import Dataset
 
 from tonic.audio_augmentations import RandomPitchShift
 from tonic.audio_transforms import AmplitudeScale, FixLength
 from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache
 
 
-class mini_dataset(Dataset):
+class mini_dataset:
     def __init__(self) -> None:
-        super().__init__()
         np.random.seed(0)
         self.data = np.random.rand(10, 16000)
         self.transform = None
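
A note on the augmentation interface this patch series relies on: after PATCH 06, `Aug_DiskCachedDataset.generate_copy` sets `self.aug[0].aug_index = copy` before composing the pipeline, so any augmentation used with it must be a callable with a mutable `aug_index` attribute that selects one value out of a discrete parameter list, with `caching=True` switching it from random to deterministic behaviour. The sketch below illustrates that contract with a hypothetical `RandomAmplitudeScale` class (not part of tonic); the `factors`/`caching`/`aug_index` names follow the `RandomPitchShift` usage shown in the test and notebook above, but the implementation here is only an assumption-based example.

import numpy as np


class RandomAmplitudeScale:
    # Hypothetical augmentation (not part of tonic), sketched to illustrate the
    # contract Aug_DiskCachedDataset relies on: a discrete parameter list
    # (`factors`), a `caching` flag, and an `aug_index` that pins one factor.

    def __init__(self, factors=(0.2, 0.4, 0.6, 0.8, 1.0), caching=False, aug_index=0):
        self.factors = list(factors)  # discrete, non-probabilistic parameter space
        self.caching = caching
        self.aug_index = aug_index  # Aug_DiskCachedDataset sets this to the copy index

    def __call__(self, data):
        if self.caching:
            # deterministic: cached copy i always uses factors[i]
            factor = self.factors[self.aug_index]
        else:
            # outside the caching path, fall back to a random draw
            factor = self.factors[np.random.randint(len(self.factors))]
        return data * factor

With `num_copies=len(aug.factors)`, the cache file `{item}_{i}.hdf5` then always holds the variant produced by `factors[i]`, which is exactly the property the equality assertions in `test_aug_disk_caching` and the notebook verify.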