Skip to content

Commit

Permalink
Fix an4 download, and dockerfile (SeanNaren#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Apr 22, 2022
1 parent db52612 commit d169c84
Show file tree
Hide file tree
Showing 12 changed files with 27 additions and 153 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
FROM pytorch/pytorch:latest
FROM nvcr.io/nvidia/pytorch:22.03-py3
ENV DEBIAN_FRONTEND=noninteractive

ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH

WORKDIR /workspace/
Expand Down
6 changes: 3 additions & 3 deletions configs/an4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ data:
num_workers: 8
trainer:
max_epochs: 70
gpus: 1
accelerator: 'auto'
devices: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
strategy: ddp
enable_checkpointing: True
checkpoint:
save_top_k: 1
Expand Down
6 changes: 3 additions & 3 deletions configs/commonvoice.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ data:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
accelerator: 'auto'
devices: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
strategy: ddp
enable_checkpointing: True
checkpoint:
save_top_k: 1
Expand Down
6 changes: 3 additions & 3 deletions configs/librispeech.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ data:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
accelerator: 'auto'
devices: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
strategy: ddp
enable_checkpointing: True
checkpoint:
save_top_k: 1
Expand Down
6 changes: 3 additions & 3 deletions configs/tedlium.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ data:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
accelerator: 'auto'
devices: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
strategy: ddp
enable_checkpointing: True
checkpoint:
save_top_k: 1
Expand Down
134 changes: 5 additions & 129 deletions data/an4.py
Original file line number Diff line number Diff line change
@@ -1,145 +1,24 @@
import argparse
import os
import io
import shutil
import tarfile

from sklearn.model_selection import train_test_split
import wget

from deepspeech_pytorch.data.data_opts import add_data_opts
from deepspeech_pytorch.data.utils import create_manifest


def _format_training_data(root_path,
val_fraction,
sample_rate,
target_dir):
wav_path = root_path + 'wav/'
file_ids_path = root_path + 'etc/an4_train.fileids'
transcripts_path = root_path + 'etc/an4_train.transcription'
root_wav_path = wav_path + 'an4_clstk'

_convert_audio_to_wav(an4_audio_path=root_wav_path,
sample_rate=sample_rate)
file_ids, transcripts = _retrieve_file_ids_and_transcripts(file_ids_path, transcripts_path)

split_files = train_test_split(file_ids, transcripts, test_size=val_fraction)
train_file_ids, val_file_ids, train_transcripts, val_transcripts = split_files

_save_wav_transcripts(data_type='train',
file_ids=train_file_ids,
transcripts=train_transcripts,
wav_dir=wav_path,
target_dir=target_dir)
_save_wav_transcripts(data_type='val',
file_ids=val_file_ids,
transcripts=val_transcripts,
wav_dir=wav_path,
target_dir=target_dir)


def _format_test_data(root_path,
sample_rate,
target_dir):
wav_path = root_path + 'wav/'
file_ids_path = root_path + 'etc/an4_test.fileids'
transcripts_path = root_path + 'etc/an4_test.transcription'
root_wav_path = wav_path + 'an4test_clstk'

_convert_audio_to_wav(an4_audio_path=root_wav_path,
sample_rate=sample_rate)
file_ids, transcripts = _retrieve_file_ids_and_transcripts(file_ids_path, transcripts_path)

_save_wav_transcripts(data_type='test',
file_ids=file_ids,
transcripts=transcripts,
wav_dir=wav_path,
target_dir=target_dir)


def _save_wav_transcripts(data_type,
file_ids,
transcripts,
wav_dir,
target_dir):
data_path = os.path.join(target_dir, data_type + '/an4/')
new_transcript_dir = data_path + '/txt/'
new_wav_dir = data_path + '/wav/'

os.makedirs(new_transcript_dir)
os.makedirs(new_wav_dir)

_save_files(file_ids=file_ids,
transcripts=transcripts,
wav_dir=wav_dir,
new_wav_dir=new_wav_dir,
new_transcript_dir=new_transcript_dir)


def _convert_audio_to_wav(an4_audio_path, sample_rate):
with os.popen('find %s -type f -name "*.raw"' % an4_audio_path) as pipe:
for line in pipe:
raw_path = line.strip()
new_path = line.replace('.raw', '.wav').strip()
cmd = 'sox -t raw -r %d -b 16 -e signed-integer -B -c 1 \"%s\" \"%s\"' % (
sample_rate, raw_path, new_path)
os.system(cmd)


def _save_files(file_ids, transcripts, wav_dir, new_wav_dir, new_transcript_dir):
for file_id, transcript in zip(file_ids, transcripts):
path = wav_dir + file_id.strip() + '.wav'
filename = path.split('/')[-1]
extracted_transcript = _process_transcript(transcript)
new_path = new_wav_dir + filename
text_path = new_transcript_dir + filename.replace('.wav', '.txt')
with io.FileIO(text_path, "w") as file:
file.write(extracted_transcript.encode('utf-8'))
current_path = os.path.abspath(path)
shutil.copy(current_path, new_path)
os.remove(current_path)


def _retrieve_file_ids_and_transcripts(file_id_path, transcripts_path):
with open(file_id_path, 'r') as f:
file_ids = f.readlines()
with open(transcripts_path, 'r') as t:
transcripts = t.readlines()
return file_ids, transcripts


def _process_transcript(transcript):
"""
Removes tags found in AN4.
"""
extracted_transcript = transcript.split('(')[0].strip("<s>").split('<')[0].strip().upper()
return extracted_transcript


def download_an4(target_dir: str,
manifest_dir: str,
min_duration: float,
max_duration: float,
val_fraction: float,
sample_rate: int,
num_workers: int):
root_path = 'an4/'
raw_tar_path = 'an4_raw.bigendian.tar.gz'
raw_tar_path = 'an4.tar.gz'
if not os.path.exists(raw_tar_path):
wget.download('http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz')
tar = tarfile.open('an4_raw.bigendian.tar.gz')
tar.extractall()
wget.download('https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/an4.tar.gz')
tar = tarfile.open('an4.tar.gz')
os.makedirs(target_dir, exist_ok=True)
_format_training_data(root_path=root_path,
val_fraction=val_fraction,
sample_rate=sample_rate,
target_dir=target_dir)
_format_test_data(root_path=root_path,
sample_rate=sample_rate,
target_dir=target_dir)
shutil.rmtree(root_path)
os.remove('an4_raw.bigendian.tar.gz')
tar.extractall(target_dir)
train_path = target_dir + '/train/'
val_path = target_dir + '/val/'
test_path = target_dir + '/test/'
Expand Down Expand Up @@ -167,15 +46,12 @@ def download_an4(target_dir: str,
parser = argparse.ArgumentParser(description='Processes and downloads an4.')
parser = add_data_opts(parser)
parser.add_argument('--target-dir', default='an4_dataset/', help='Path to save dataset')
parser.add_argument('--val-fraction', default=0.1, type=float,
help='Number of files in the training set to use as validation.')
args = parser.parse_args()
assert args.sample_rate == 16000, "AN4 only supports sample rate of 16000 currently."
download_an4(
target_dir=args.target_dir,
manifest_dir=args.manifest_dir,
min_duration=args.min_duration,
max_duration=args.max_duration,
val_fraction=args.val_fraction,
sample_rate=args.sample_rate,
num_workers=args.num_workers
)
1 change: 0 additions & 1 deletion deepspeech_pytorch/configs/lightning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,3 @@ class TrainerConf:
reload_dataloaders_every_n_epochs: int = 0
multiple_trainloader_mode: str = "max_size_cycle"
stochastic_weight_avg: bool = False
terminate_on_nan: Optional[bool] = None
3 changes: 1 addition & 2 deletions deepspeech_pytorch/configs/train_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, List

from hydra_configs.pytorch_lightning.callbacks import ModelCheckpointConf
from omegaconf import MISSING

from deepspeech_pytorch.configs.lightning_config import TrainerConf
from deepspeech_pytorch.configs.lightning_config import TrainerConf, ModelCheckpointConf
from deepspeech_pytorch.enums import SpectrogramWindow, RNNType

defaults = [
Expand Down
8 changes: 5 additions & 3 deletions deepspeech_pytorch/loader/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ class DeepSpeechDataModule(pl.LightningDataModule):
def __init__(self,
labels: list,
data_cfg: DataConfig,
normalize: bool,
is_distributed: bool):
normalize: bool):
super().__init__()
self.train_path = to_absolute_path(data_cfg.train_path)
self.val_path = to_absolute_path(data_cfg.val_path)
Expand All @@ -21,7 +20,10 @@ def __init__(self,
self.spect_cfg = data_cfg.spect
self.aug_cfg = data_cfg.augmentation
self.normalize = normalize
self.is_distributed = is_distributed

@property
def is_distributed(self):
return self.trainer.devices > 1

def train_dataloader(self):
train_dataset = self._create_dataset(self.train_path)
Expand Down
1 change: 0 additions & 1 deletion deepspeech_pytorch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def train(cfg: DeepSpeechConfig):
labels=labels,
data_cfg=cfg.data,
normalize=True,
is_distributed=cfg.trainer.gpus > 1
)

model = DeepSpeech(
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
fairscale
flask
hydra-core
jupyter
Expand All @@ -17,4 +16,4 @@ torchaudio
torchelastic
tqdm
wget
git+https://github.com/romesco/hydra-lightning/#subdirectory=hydra-configs-pytorch-lightning
git+https://github.com/romesco/hydra-lightning/#subdirectory=hydra-configs-pytorch-lightning
2 changes: 0 additions & 2 deletions tests/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def download_data(self,
manifest_dir=cfg.manifest_dir,
min_duration=cfg.min_duration,
max_duration=cfg.max_duration,
val_fraction=cfg.val_fraction,
sample_rate=cfg.sample_rate,
num_workers=cfg.num_workers
)

Expand Down

0 comments on commit d169c84

Please sign in to comment.