-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #274 from MinaKh/add_Aug_DiskCachedDataset
AugDiskCachedDataset to map the copy index to augmentation parameter
- Loading branch information
Showing
5 changed files
with
338 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Using Aug_DiskCachedDataset for efficient caching of augmented copies\n", | ||
"- `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` that is useful while applying deterministic augmentations on data samples. \n", | ||
"\n", | ||
"- This is the case when the parameter space of augmentation is desceret, for instance applying `pitchshift` on audio data in which shift parameter (semitone) can only take N values.\n", | ||
"\n", | ||
"- Using `DiskCachedDataset` and setting `num_copies` to N is likely to cause 2 issues:\n", | ||
"\n", | ||
" - Copies might not be unique, as copy_index is not linked to the augmentation parameter \n", | ||
" - And there is no guarantee that copies cover the desired augmentation space\n", | ||
" \n", | ||
"\n", | ||
"\n", | ||
"- `Aug_DiskCachedDataset` resolves this limitation by mapping and linking copy index to augmentation parameter. Following considerations need to be takes into account:\n", | ||
"\n", | ||
" - The user needs to pass `all_transforms` dict as input with seperated transforms `pre_aug`, `aug`, `post_aug` (spesifying transforms that are applied before and after augmentations, also augmentation transforms). \n", | ||
" \n", | ||
" - The augmentation class receives `aug_index` (aug_index = copy) as initialization parameter also `caching=True` needs to be set (please see `tonic.audio_augmentations`)\n", | ||
"\n", | ||
"- Follwing is a simple example to show function of `Aug_DiskCachedDataset`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### A simple dataset " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %%writefile mini_dataset.py\n", | ||
"import warnings\n", | ||
"warnings.filterwarnings('ignore')\n", | ||
"from torch.utils.data import Dataset\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"class mini_dataset(Dataset):\n", | ||
" def __init__(self) -> None:\n", | ||
" super().__init__()\n", | ||
" np.random.seed(0)\n", | ||
" self.data = np.random.rand(10, 16000)\n", | ||
" self.transform = None\n", | ||
" self.target_transform = None\n", | ||
"\n", | ||
" def __getitem__(self, index):\n", | ||
" sample = self.data[index]\n", | ||
" label = 1\n", | ||
" if sample.ndim==1:\n", | ||
" sample = sample[None,...]\n", | ||
" if self.transform is not None:\n", | ||
" sample = self.transform(sample)\n", | ||
" if self.target_transform is not None:\n", | ||
" label = self.target_transform(label) \n", | ||
"\n", | ||
" return sample, label " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Initializing `Aug_DiskCachedDataset` with transforms" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache\n", | ||
"from tonic.audio_transforms import AmplitudeScale, FixLength\n", | ||
"from tonic.audio_augmentations import RandomPitchShift\n", | ||
"\n", | ||
"all_transforms = {}\n", | ||
"all_transforms[\"pre_aug\"] = [AmplitudeScale(max_amplitude = 0.150)]\n", | ||
"all_transforms[\"augmentations\"] = [RandomPitchShift(samplerate=16000, caching=True)]\n", | ||
"all_transforms[\"post_aug\"] = [FixLength(16000)]\n", | ||
"\n", | ||
"# number of copies is set to number of augmentation params (factors)\n", | ||
"n = len(RandomPitchShift(samplerate=16000, caching=True).factors)\n", | ||
"Aug_cach = Aug_DiskCachedDataset(dataset=mini_dataset(), cache_path='cache/', all_transforms = all_transforms, num_copies=n)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Generating all copies of a data sample\n", | ||
" - 10 augmented versions of data sample with index = 0 are generated" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sample_index = 0\n", | ||
"Aug_cach.generate_all(sample_index)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### To verify\n", | ||
" - loading the saved copies \n", | ||
" - and comparing them with the ones generated out of `Aug_DiskCacheDataset` with the same transforms and matching augmentation parameter \n", | ||
" - they are equal\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n", | ||
"True\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from torchvision.transforms import Compose\n", | ||
"\n", | ||
"for i in range(n):\n", | ||
" transform = Compose([AmplitudeScale(max_amplitude = 0.150),RandomPitchShift(samplerate=16000, caching=True, aug_index=i), FixLength(16000)])\n", | ||
" ds = mini_dataset()\n", | ||
" ds.transform = transform\n", | ||
" sample = ds[sample_index][0]\n", | ||
" data, targets = load_from_disk_cache('cache/' + '0_' + str(i) + '.hdf5' )\n", | ||
" print((sample==data).all())\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "py_310", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ typing_extensions | |
librosa | ||
pbr | ||
expelliarmus | ||
typing-extensions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
|
||
import numpy as np | ||
|
||
from tonic.audio_augmentations import RandomPitchShift | ||
from tonic.audio_transforms import AmplitudeScale, FixLength | ||
from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache | ||
|
||
|
||
class mini_dataset: | ||
def __init__(self) -> None: | ||
np.random.seed(0) | ||
self.data = np.random.rand(10, 16000) | ||
self.transform = None | ||
self.target_transform = None | ||
|
||
def __getitem__(self, index): | ||
sample = self.data[index] | ||
label = 1 | ||
if sample.ndim == 1: | ||
sample = sample[None, ...] | ||
if self.transform is not None: | ||
sample = self.transform(sample) | ||
if self.target_transform is not None: | ||
label = self.target_transform(label) | ||
|
||
return sample, label | ||
|
||
|
||
def test_aug_disk_caching(): | ||
from torchvision.transforms import Compose | ||
|
||
all_transforms = {} | ||
all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)] | ||
all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)] | ||
all_transforms["post_aug"] = [FixLength(16000)] | ||
# number of copies is set to number of augmentation params (factors) | ||
n = len(RandomPitchShift(samplerate=16000, caching=True).factors) | ||
Aug_cach = Aug_DiskCachedDataset( | ||
dataset=mini_dataset(), | ||
cache_path="cache/", | ||
all_transforms=all_transforms, | ||
num_copies=n, | ||
) | ||
|
||
if not os.path.isdir("cache/"): | ||
os.mkdir("cache/") | ||
|
||
sample_index = 0 | ||
Aug_cach.generate_all(sample_index) | ||
|
||
for i in range(n): | ||
transform = Compose( | ||
[ | ||
AmplitudeScale(max_amplitude=0.150), | ||
RandomPitchShift(samplerate=16000, caching=True, aug_index=i), | ||
FixLength(16000), | ||
] | ||
) | ||
ds = mini_dataset() | ||
ds.transform = transform | ||
augmented_sample = ds[sample_index][0] | ||
loaded_sample, targets = load_from_disk_cache( | ||
"cache/" + "0_" + str(i) + ".hdf5" | ||
) | ||
assert (augmented_sample == loaded_sample).all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters