AugDiskCachedDataset to map the copy index to augmentation parameter #274

Merged
merged 17 commits into from
May 30, 2024
Changes from 15 commits
179 changes: 179 additions & 0 deletions docs/tutorials/Aug_DiskCachDataset.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Aug_DiskCachedDataset for efficient caching of augmented copies\n",
"- `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` for caching samples that are augmented deterministically.\n",
"\n",
"- This applies when the augmentation parameter space is discrete, for instance applying `pitchshift` to audio data, where the shift parameter (in semitones) can take only N values.\n",
"\n",
"- Using `DiskCachedDataset` with `num_copies` set to N is likely to cause two issues:\n",
"\n",
"    - Copies might not be unique, because `copy_index` is not linked to the augmentation parameter.\n",
"    - There is no guarantee that the copies cover the desired augmentation space.\n",
"\n",
"- `Aug_DiskCachedDataset` resolves this limitation by mapping the copy index to the augmentation parameter. The following considerations need to be taken into account:\n",
"\n",
"    - The user needs to pass an `all_transforms` dict with the transforms separated under the keys `pre_aug`, `augmentations` and `post_aug` (the transforms applied before the augmentations, the augmentation transforms themselves, and the transforms applied after them).\n",
"\n",
"    - The augmentation class receives `aug_index` (`aug_index = copy`) as an initialization parameter, and `caching=True` needs to be set (see `tonic.audio_augmentations`).\n",
"\n",
"- The following simple example shows how `Aug_DiskCachedDataset` works."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A simple dataset "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# %%writefile mini_dataset.py\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"from torch.utils.data import Dataset\n",
"import numpy as np\n",
"\n",
"class mini_dataset(Dataset):\n",
"    def __init__(self) -> None:\n",
"        super().__init__()\n",
"        np.random.seed(0)\n",
"        self.data = np.random.rand(10, 16000)\n",
"        self.transform = None\n",
"        self.target_transform = None\n",
"\n",
"    def __getitem__(self, index):\n",
"        sample = self.data[index]\n",
"        label = 1\n",
"        if sample.ndim == 1:\n",
"            sample = sample[None, ...]\n",
"        if self.transform is not None:\n",
"            sample = self.transform(sample)\n",
"        if self.target_transform is not None:\n",
"            label = self.target_transform(label)\n",
"\n",
"        return sample, label"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initializing `Aug_DiskCachedDataset` with transforms"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache\n",
"from tonic.audio_transforms import AmplitudeScale, FixLength\n",
"from tonic.audio_augmentations import RandomPitchShift\n",
"\n",
"all_transforms = {}\n",
"all_transforms[\"pre_aug\"] = [AmplitudeScale(max_amplitude = 0.150)]\n",
"all_transforms[\"augmentations\"] = [RandomPitchShift(samplerate=16000, caching=True)]\n",
"all_transforms[\"post_aug\"] = [FixLength(16000)]\n",
"\n",
"# number of copies is set to number of augmentation params (factors)\n",
"n = len(RandomPitchShift(samplerate=16000, caching=True).factors)\n",
"Aug_cach = Aug_DiskCachedDataset(dataset=mini_dataset(), cache_path='cache/', all_transforms = all_transforms, num_copies=n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating all copies of a data sample\n",
"- 10 augmented versions of the data sample with index 0 are generated, one per augmentation factor."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sample_index = 0\n",
"Aug_cach.generate_all(sample_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### To verify\n",
"- Load the saved copies.\n",
"- Compare them with samples generated directly, without `Aug_DiskCachedDataset`, using the same transforms and the matching augmentation parameter.\n",
"- They are equal.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"from torchvision.transforms import Compose\n",
"\n",
"for i in range(n):\n",
"    transform = Compose([AmplitudeScale(max_amplitude=0.150), RandomPitchShift(samplerate=16000, caching=True, aug_index=i), FixLength(16000)])\n",
"    ds = mini_dataset()\n",
"    ds.transform = transform\n",
"    sample = ds[sample_index][0]\n",
"    data, targets = load_from_disk_cache('cache/' + '0_' + str(i) + '.hdf5')\n",
"    print((sample == data).all())\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py_310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
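The notebook's motivation, that independently drawn copies may repeat a parameter while index-mapped copies cannot, can be illustrated without tonic. The following is a minimal sketch under stated assumptions: `FACTORS`, `params_by_random_draw` and `params_by_copy_index` are hypothetical names introduced here for illustration, and the factor values are not taken from `RandomPitchShift`.

```python
import random

# Hypothetical discrete parameter space (e.g. pitch shifts in semitones).
# These values are an assumption for illustration, not tonic's actual factors.
FACTORS = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]


def params_by_random_draw(num_copies, seed):
    """Each copy draws its parameter independently, so duplicates are possible."""
    rng = random.Random(seed)
    return [rng.choice(FACTORS) for _ in range(num_copies)]


def params_by_copy_index(num_copies):
    """Copy i always uses FACTORS[i], so every parameter appears exactly once."""
    return [FACTORS[i] for i in range(num_copies)]


n = len(FACTORS)
indexed = params_by_copy_index(n)
drawn = params_by_random_draw(n, seed=0)

print("index mapping covers all factors:", set(indexed) == set(FACTORS))
print("distinct factors among random draws:", len(set(drawn)), "of", n)
```

With the index mapping, setting `num_copies` to the number of factors yields one cached copy per parameter; with independent draws, coverage depends on chance.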
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -7,3 +7,4 @@ typing_extensions
librosa
pbr
expelliarmus
typing-extensions
67 changes: 67 additions & 0 deletions test/test_aug_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from tonic.audio_augmentations import RandomPitchShift
from tonic.audio_transforms import AmplitudeScale, FixLength
from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache


class mini_dataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        np.random.seed(0)
        self.data = np.random.rand(10, 16000)
        self.transform = None
        self.target_transform = None

    def __getitem__(self, index):
        sample = self.data[index]
        label = 1
        if sample.ndim == 1:
            sample = sample[None, ...]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label


def test_aug_disk_caching():
    all_transforms = {}
    all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
    all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]
    all_transforms["post_aug"] = [FixLength(16000)]
    # number of copies is set to number of augmentation params (factors)
    n = len(RandomPitchShift(samplerate=16000, caching=True).factors)
    Aug_cach = Aug_DiskCachedDataset(
        dataset=mini_dataset(),
        cache_path="cache/",
        all_transforms=all_transforms,
        num_copies=n,
    )

    if not os.path.isdir("cache/"):
        os.mkdir("cache/")

    sample_index = 0
    Aug_cach.generate_all(sample_index)

    for i in range(n):
        transform = Compose(
            [
                AmplitudeScale(max_amplitude=0.150),
                RandomPitchShift(samplerate=16000, caching=True, aug_index=i),
                FixLength(16000),
            ]
        )
        ds = mini_dataset()
        ds.transform = transform
        augmented_sample = ds[sample_index][0]
        loaded_sample, targets = load_from_disk_cache(
            "cache/" + "0_" + str(i) + ".hdf5"
        )
        assert (augmented_sample == loaded_sample).all()
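The test above relies on two properties: cached files are named `{item}_{copy}`, and generating all copies only fills in the ones that are missing. The pattern can be sketched with the standard library alone. This is a rough illustration rather than tonic's implementation: JSON files stand in for HDF5, and `GAINS`, `augment` and `cache_file` are hypothetical names introduced here.

```python
import json
import os
import tempfile

# Hypothetical index-driven augmentation: copy i scales the sample by GAINS[i].
GAINS = [0.5, 1.0, 2.0]


def cache_file(cache_path, item, copy):
    """File name that encodes both the sample index and the copy index."""
    return os.path.join(cache_path, f"{item}_{copy}.json")


def augment(sample, copy):
    """Deterministic augmentation selected by the copy index."""
    return [x * GAINS[copy] for x in sample]


def generate_copy(data, cache_path, item, copy):
    """Augment one sample with the parameter for `copy` and store it."""
    with open(cache_file(cache_path, item, copy), "w") as f:
        json.dump(augment(data[item], copy), f)


def generate_all(data, cache_path, item):
    """Create only the copies that are missing; return the indices written."""
    written = []
    for copy in range(len(GAINS)):
        if not os.path.exists(cache_file(cache_path, item, copy)):
            generate_copy(data, cache_path, item, copy)
            written.append(copy)
    return written


data = [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
with tempfile.TemporaryDirectory() as tmp:
    first = generate_all(data, tmp, item=0)   # nothing cached yet
    second = generate_all(data, tmp, item=0)  # every copy already cached
    with open(cache_file(tmp, 0, 2)) as f:
        reloaded = json.load(f)
    matches = reloaded == augment(data[0], 2)
```

Because the augmentation depends only on the copy index, a cached copy can be checked by regenerating it with the same index, which is the comparison the test performs.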
7 changes: 6 additions & 1 deletion tonic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from pbr.version import VersionInfo

from . import collation, datasets, io, slicers, transforms, utils
from .cached_dataset import CachedDataset, DiskCachedDataset, MemoryCachedDataset
from .cached_dataset import (
    Aug_DiskCachedDataset,
    CachedDataset,
    DiskCachedDataset,
    MemoryCachedDataset,
)
from .dataset import Dataset
from .sliced_dataset import SlicedDataset

87 changes: 86 additions & 1 deletion tonic/cached_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import logging
import os
import sys

if sys.version_info >= (3, 8):
    from typing import Callable, Iterable, Optional, Tuple, TypedDict, Union
else:
    from typing import Callable, Iterable, Optional, Tuple, Union
    from typing_extensions import TypedDict

import random
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple, Union
from warnings import warn

import h5py
@@ -227,6 +235,83 @@ def load_from_disk_cache(file_path: Union[str, Path]) -> Tuple:
    return data_list, target_list


@dataclass
class Aug_DiskCachedDataset(DiskCachedDataset):
    """Aug_DiskCachedDataset is a child class of DiskCachedDataset with further customizations to
    handle augmented copies of a sample. The goal of this customization is to map the indices of
    cached files (copies) to augmentation parameters. This is useful for a category of augmentations
    where the parameter range is discrete and non-probabilistic, for instance when an audio
    sample is augmented with noise and the SNR can take only N=5 values. Passing copy_index to
    the augmentation class as an init argument ensures that each copy will be a distinct augmented
    sample with a trackable parameter.

    The 'generate_all' method generates all augmented versions of a sample.
    The 'generate_copy' method generates a missing variant (augmented version).

    All transforms applied to the dataset are therefore categorized by the keys "pre_aug",
    "augmentations" and "post_aug".

    Args:
        all_transforms: a dictionary containing all transforms, grouped under the keys above.
    """

    all_transforms: Optional[TypedDict] = None

    def __post_init__(self):
        self.pre_aug = self.all_transforms["pre_aug"]
        self.aug = self.all_transforms["augmentations"]
        self.post_aug = self.all_transforms["post_aug"]

    def generate_all(self, item):
        for copy in range(0, self.num_copies):
            file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
            try:
                data, targets = load_from_disk_cache(file_path)
            except (FileNotFoundError, OSError) as _:
                self.generate_copy(item, copy)

    def generate_copy(self, item, copy):
        from torchvision.transforms import Compose

        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
        # copy index is passed to augmentation (callable)
        self.aug[0].aug_index = copy
        augmentation = self.aug
        self.dataset.transform = Compose(self.pre_aug + augmentation + self.post_aug)
        data, targets = self.dataset[item]
        save_to_disk_cache(data, targets, file_path=file_path, compress=self.compress)

    def __getitem__(self, item) -> Tuple[object, object]:
        if self.dataset is None and item >= self.n_samples:
            raise IndexError(f"This dataset only has {self.n_samples} items.")

        copy = random.randint(0, self.num_copies - 1)
        file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
        try:
            data, targets = load_from_disk_cache(file_path)

        except (FileNotFoundError, OSError) as _:
            logging.info(
                f"Data {item}: {file_path} not in cache, generating it now",
                stacklevel=2,
            )
            self.generate_copy(item, copy)

            # format might change during save to hdf5, i.e. tensors -> np arrays
            # We load the sample here again to keep the output format consistent.
            data, targets = load_from_disk_cache(file_path)

        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
        if self.transforms is not None:
            data, targets = self.transforms(data, targets)
        return data, targets


class CachedDataset(DiskCachedDataset):
    """Deprecated class that points to DiskCachedDataset for now but will be removed in a future
    release.