Merge pull request mlcommons#311 from znado/speech-ds
Cleaning up speech datasets
znado authored Feb 8, 2023
2 parents d7873a5 + 4e4b415 commit fc02196
Showing 23 changed files with 213 additions and 266 deletions.
8 changes: 7 additions & 1 deletion algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -1,7 +1,7 @@
"""CIFAR workload implemented in Jax."""

import functools
from typing import Dict, Iterator, Optional, Tuple
from typing import Any, Dict, Iterator, Optional, Tuple

from flax import jax_utils
from flax import linen as nn
@@ -190,3 +190,9 @@ def _eval_model(
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
11 changes: 10 additions & 1 deletion algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py
@@ -2,10 +2,11 @@

import contextlib
import random
from typing import Dict, Iterator, Optional, Tuple
from typing import Any, Dict, Iterator, Optional, Tuple

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import transforms
@@ -209,3 +210,11 @@ def _eval_model(
_, per_example_losses = self.loss_fn(batch['targets'], logits, weights)
loss = per_example_losses.sum()
return {'accuracy': accuracy, 'loss': loss}

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
21 changes: 9 additions & 12 deletions algorithmic_efficiency/workloads/cifar/workload.py
@@ -1,18 +1,16 @@
"""CIFAR workload parent class."""

import abc
import math
from typing import Dict, Tuple
from typing import Any, Dict, Tuple

from absl import flags
import jax
import torch
import torch.distributed as dist

from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup
import algorithmic_efficiency.random_utils as prng

FLAGS = flags.FLAGS
USE_PYTORCH_DDP, _, _, _ = pytorch_setup()


@@ -106,6 +104,12 @@ def _eval_model(
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
raise NotImplementedError

@abc.abstractmethod
def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""

def _eval_model_on_split(self,
split: str,
num_examples: int,
@@ -143,11 +147,4 @@ def _eval_model_on_split(self,
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

if FLAGS.framework == 'jax':
eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
eval_metrics)
return eval_metrics
elif USE_PYTORCH_DDP:
for metric in eval_metrics.values():
dist.all_reduce(metric)
return {k: float(v.item() / num_examples) for k, v in eval_metrics.items()}
return self._normalize_eval_metrics(num_examples, eval_metrics)
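
Taken together, the three CIFAR changes above move the framework branch that used to live in the shared _eval_model_on_split into a per-framework _normalize_eval_metrics hook: the parent class only sums per-batch metrics, the Jax subclass divides the replicated sums by num_examples, and the PyTorch subclass all-reduces across DDP processes before dividing. A rough, self-contained toy sketch of that pattern (simplified, single-process names, not the repository classes):

# Toy sketch of the refactor: the base class accumulates per-batch metric
# sums and delegates normalization to a framework-specific hook.
import abc
from typing import Any, Dict, Iterable, Tuple


class EvalWorkload(abc.ABC):

  @abc.abstractmethod
  def _normalize_eval_metrics(
      self, num_examples: int,
      total_metrics: Dict[str, Any]) -> Dict[str, float]:
    """Turn summed metrics into per-example averages."""

  def eval_on_batches(
      self, batches: Iterable[Tuple[Dict[str, float], int]]
  ) -> Dict[str, float]:
    total_metrics: Dict[str, float] = {}
    num_examples = 0
    for batch_metrics, batch_size in batches:
      num_examples += batch_size
      for name, value in batch_metrics.items():
        total_metrics[name] = total_metrics.get(name, 0.0) + value
    return self._normalize_eval_metrics(num_examples, total_metrics)


class SingleProcessWorkload(EvalWorkload):
  """A DDP variant would dist.all_reduce each metric before dividing."""

  def _normalize_eval_metrics(self, num_examples, total_metrics):
    return {k: float(v / num_examples) for k, v in total_metrics.items()}


# Two batches of (summed metrics, batch size): averages are 17/6 and 5/6.
batches = [({'loss': 12.0, 'accuracy': 3.0}, 4),
           ({'loss': 5.0, 'accuracy': 2.0}, 2)]
print(SingleProcessWorkload().eval_on_batches(batches))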
algorithmic_efficiency/workloads/librispeech_conformer/input_pipeline.py
@@ -1,42 +1,40 @@
"""Data loader for pre-processed librispeech data."""
"""
Sharing the jax input pipeline slows down the data loading
and step times.
"""
import csv
from typing import Optional

from absl import logging
import numpy as np
import tensorflow as tf
import torch

from algorithmic_efficiency import spec

class LibriSpeechDataset(torch.utils.data.Dataset):

def get_librispeech_dataset(split_name: str,
data_dir: str,
shuffle_rng: spec.RandomState,
is_training: bool,
global_batch_size: int,
num_batches: Optional[int] = None):
"""Get the Librispeech dataset for a given split."""
splits = [split_name]
def __init__(self, split, data_dir):
super().__init__()
self.data_dir = data_dir
splits = split.split("+")
ids = []
for split in splits:
logging.info('Loading split = %s', split)
feat_csv = '{}/{}.csv'.format(data_dir, split)

if split_name.find('+') != -1:
splits = split_name.split('+')
with open(feat_csv, newline='') as csvfile:
data = list(csv.reader(csvfile))

ids = []
for example in data[1:]:
ids.append('{}/{}'.format(split, example[1]))
self.ids = ids

for split in splits:
logging.info(f'Loading split = {split}.')
feat_csv = f'{data_dir}/{split}.csv'
def __len__(self):
return len(self.ids)

with open(feat_csv, newline='') as csvfile:
data = list(csv.reader(csvfile))

for example in data[1:]:
ids.append(f'{split}/{example[1]}')

def load_data(example_id):
example_id = example_id.decode('utf-8')
audio = np.load(f'{data_dir}/{example_id}_audio.npy')
targets = np.load(f'{data_dir}/{example_id}_targets.npy')
def __getitem__(self, index):
example_id = self.ids[index]
data_dir = self.data_dir
audio = np.load('{}/{}_audio.npy'.format(data_dir, example_id))
targets = np.load('{}/{}_targets.npy'.format(data_dir, example_id))

audio_paddings = np.zeros_like(audio, dtype=np.float32)
audio_paddings = np.pad(
@@ -48,42 +46,8 @@ def load_data(example_id):
target_paddings, (0, 256 - target_paddings.shape[0]),
constant_values=1.0)
targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0)

return audio, audio_paddings, targets, target_paddings

def preprocess(example):
example_id = example['ids']

preprocessed_example = {}
audio, audio_paddings, targets, target_paddings = tf.numpy_function(
func=load_data,
inp=[example_id],
Tout=[tf.int64, tf.float32, tf.int32, tf.float32])

# Make batches of tuples of (tensor, padding)
preprocessed_example['inputs'] = (audio, audio_paddings)
preprocessed_example['targets'] = (targets, target_paddings)

return preprocessed_example

ds = tf.data.Dataset.from_tensor_slices({'ids': ids})
ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0])

ds = ds.map(preprocess, num_parallel_calls=10)

if is_training:
ds = ds.repeat()

if split in ['train', 'eval_train']:
ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0])

ds = ds.batch(global_batch_size, drop_remainder=is_training)

if is_training:
ds = ds.repeat()

if num_batches is not None:
ds = ds.take(num_batches)

ds = ds.prefetch(10)
return ds
audio = audio.astype(np.float32)
audio_paddings = audio_paddings.astype(np.float32)
targets = targets.astype(np.float32)
target_paddings = target_paddings.astype(np.float32)
return (audio, audio_paddings), (targets, target_paddings)
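
The deleted TF-based get_librispeech_dataset pipeline is replaced by this plain torch.utils.data.Dataset, which both frameworks now share. A minimal usage sketch, assuming data_dir holds the pre-processed layout the class reads ({split}.csv index files plus per-example *_audio.npy / *_targets.npy arrays); the path and split string here are illustrative only:

# Minimal usage sketch; data_dir is a placeholder for a directory with the
# pre-processed CSV/NPY layout described above.
from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \
    LibriSpeechDataset

ds = LibriSpeechDataset(split='dev-clean+dev-other',
                        data_dir='/data/librispeech')
print(len(ds))  # number of utterance ids across both splits

# Each item is ((audio, audio_paddings), (targets, target_paddings)), all
# float32 numpy arrays padded to fixed lengths, with the padding masks set
# to 1.0 at padded positions.
(audio, audio_paddings), (targets, target_paddings) = ds[0]

The workloads below do not index the dataset directly; they wrap it in data_utils.cycle(torch.utils.data.DataLoader(...)), as the next diff shows.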
algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -2,24 +2,25 @@
import math
from typing import Dict, Optional, Tuple

from absl import flags
from flax import jax_utils
import flax.linen as nn
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
import optax
import torch

from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_conformer import metrics
from algorithmic_efficiency.workloads.librispeech_conformer import workload
from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \
LibriSpeechDataset
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
models

FLAGS = flags.FLAGS


class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):

@@ -28,6 +29,56 @@ def __init__(self, tokenizer_vocab_path=None, use_specaug=True):
self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path)
self.use_specaug = use_specaug

def _build_input_queue(self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int,
cache: Optional[bool] = False,
repeat_final_dataset: Optional[bool] = False,
num_batches: Optional[int] = None):
del data_rng
del cache
del repeat_final_dataset
del num_batches
train = False
if split == 'train':
split = 'train-clean-100+train-clean-360+train-other-500'
train = True
elif split == 'eval_train':
split = 'train-clean-100+train-clean-360+train-other-500'
elif split == 'validation':
split = 'dev-clean+dev-other'
elif split == 'test':
split = 'test-clean'

ds = LibriSpeechDataset(split=split, data_dir=data_dir)

dataloader = data_utils.cycle(
torch.utils.data.DataLoader(
ds,
batch_size=global_batch_size,
shuffle=train,
sampler=None,
num_workers=4,
prefetch_factor=10,
pin_memory=False,
drop_last=train,
))

for batch in iter(dataloader):
inputs, input_paddings = batch['inputs']
targets, target_paddings = batch['targets']

numpy_batch = {
'inputs': (inputs.numpy(), input_paddings.numpy()),
'targets': (targets.numpy(), target_paddings.numpy()),
}

padded_batch = data_utils.shard_and_maybe_pad_np(
numpy_batch, padding_value=1.0, global_batch_size=global_batch_size)
yield padded_batch

def init_model_fn(
self,
rng: spec.RandomState,
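
The new Jax _build_input_queue above reuses the shared LibriSpeechDataset, converts each torch batch to NumPy, and hands it to data_utils.shard_and_maybe_pad_np to lay the global batch out per device. As a rough idea of what that sharding step amounts to, here is a simplified stand-in with made-up shapes; it is not the repository helper and ignores the padding handling implied by the padding_value=1.0 argument above:

# Simplified stand-in for per-device sharding: reshape the leading (global
# batch) dimension into [num_devices, per_device_batch, ...]. Illustrative
# only -- not data_utils.shard_and_maybe_pad_np.
import jax
import numpy as np


def shard_np(batch, num_devices=None):
  num_devices = num_devices or jax.local_device_count()

  def _shard(x):
    x = np.asarray(x)
    per_device = x.shape[0] // num_devices
    return x.reshape((num_devices, per_device) + x.shape[1:])

  return jax.tree_map(_shard, batch)


# Toy batch of 8 utterances with made-up lengths (1000 audio frames,
# 256 target tokens), mirroring the {'inputs': ..., 'targets': ...} layout.
numpy_batch = {
    'inputs': (np.zeros((8, 1000), np.float32),
               np.zeros((8, 1000), np.float32)),
    'targets': (np.zeros((8, 256), np.float32),
                np.ones((8, 256), np.float32)),
}
sharded = shard_np(numpy_batch, num_devices=2)
print(sharded['inputs'][0].shape)  # (2, 4, 1000)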

This file was deleted.

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -3,7 +3,6 @@
import random
from typing import Dict, Optional, Tuple

import jax
import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
@@ -16,10 +15,10 @@
import algorithmic_efficiency.random_utils as prng
from algorithmic_efficiency.workloads.librispeech_conformer import metrics
from algorithmic_efficiency.workloads.librispeech_conformer import workload
from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \
LibriSpeechDataset
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \
model as conformer_model
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.libri_dataset import \
LibriSpeechDataset

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

@@ -110,14 +109,16 @@ def model_fn(
return (logits, logits_paddings), None

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,
repeat_final_dataset: bool = False):
del num_batches
cache: Optional[bool] = False,
repeat_final_dataset: Optional[bool] = False,
num_batches: Optional[int] = None):
del cache
del repeat_final_dataset
del num_batches

is_train = split == 'train'
if split == 'train':