From 9da5a24243a379887f3a2ab33071306436447e0f Mon Sep 17 00:00:00 2001
From: Harry He <68176557+HarryHe11@users.noreply.github.com>
Date: Wed, 17 Jan 2024 16:07:39 +0800
Subject: [PATCH] Add preprocessing scripts for the librilight datasets (#107)

* Add preprocessor for librilight dataset
---
 egs/datasets/README.md         |  26 ++-
 preprocessors/librilight.py    | 329 +++++++++++++++++++++++++++++++++
 preprocessors/processor.py     |   3 +
 utils/cut_by_vad.py            | 105 +++++++++++
 utils/mfa_prepare.py           | 116 ++++++++++++
 utils/whisper_transcription.py | 122 ++++++++++++
 6 files changed, 700 insertions(+), 1 deletion(-)
 create mode 100644 preprocessors/librilight.py
 create mode 100644 utils/cut_by_vad.py
 create mode 100644 utils/mfa_prepare.py
 create mode 100644 utils/whisper_transcription.py

diff --git a/egs/datasets/README.md b/egs/datasets/README.md
index ddc068ca..8e3d8cf8 100644
--- a/egs/datasets/README.md
+++ b/egs/datasets/README.md
@@ -7,6 +7,7 @@ Amphion support the following academic datasets (sort alphabetically):
   - [CSD](#csd)
   - [CustomSVCDataset](#customsvcdataset)
   - [KiSing](#kising)
+  - [LibriLight](#librilight)
   - [LibriTTS](#libritts)
   - [LJSpeech](#ljspeech)
   - [M4Singer](#m4singer)
@@ -84,6 +85,30 @@ Download the official KiSing dataset [here](http://shijt.site/index.php/2021/05/
 ┃ ┣ ...
 ```

+## LibriLight
+
+Download the official LibriLight dataset [here](https://github.com/facebookresearch/libri-light). The file structure looks like below:
+
+```plaintext
+[LibriLight dataset path]
+ ┣ small (Subset)
+ ┃ ┣ 100 {Speaker_ID}
+ ┃ ┃ ┣ sea_fairies_0812_librivox_64kb_mp3 {Chapter_ID}
+ ┃ ┃ ┃ ┣ 01_baum_sea_fairies_64kb.flac
+ ┃ ┃ ┃ ┣ 02_baum_sea_fairies_64kb.flac
+ ┃ ┃ ┃ ┣ 03_baum_sea_fairies_64kb.flac
+ ┃ ┃ ┃ ┣ 22_baum_sea_fairies_64kb.flac
+ ┃ ┃ ┃ ┣ 01_baum_sea_fairies_64kb.json
+ ┃ ┃ ┃ ┣ 02_baum_sea_fairies_64kb.json
+ ┃ ┃ ┃ ┣ 03_baum_sea_fairies_64kb.json
+ ┃ ┃ ┃ ┣ 22_baum_sea_fairies_64kb.json
+ ┃ ┃ ┃ ┣ ...
+ ┃ ┃ ┣ ...
+ ┃ ┣ ...
+ ┣ medium (Subset)
+ ┣ ...
+```
+
 ## LibriTTS

 Download the official LibriTTS dataset [here](https://www.openslr.org/60/). The file structure looks like below:
@@ -180,7 +205,6 @@ Download the official LibriTTS dataset [here](https://www.openslr.org/60/). The
 ┃ ┣ ...
 ```

-
 ## LJSpeech

 Download the official LJSpeech dataset [here](https://keithito.com/LJ-Speech-Dataset/). The file structure looks like below:
diff --git a/preprocessors/librilight.py b/preprocessors/librilight.py
new file mode 100644
index 00000000..eef5d6f9
--- /dev/null
+++ b/preprocessors/librilight.py
@@ -0,0 +1,329 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
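+#
+# Pipeline overview (see main() below): for every Libri-Light subset that is
+# present (tiny/small/medium/large) this preprocessor roughly performs
+#   1.   VAD-based segmentation of the raw recordings (utils/cut_by_vad.py),
+#   2-3. length filtering and 16 kHz/16-bit conversion for MFA (utils/mfa_prepare.py),
+#   4-5. Whisper transcription and text normalization (utils/whisper_transcription.py),
+#   6.   forced alignment with the Montreal Forced Aligner,
+#   7.   a speaker-level train/dev/eval split,
+#   8-9. per-split statistics and Amphion metadata files
+#        (train/dev/eval .json, singers.json, utt2singer).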
+ +import json +from tqdm import tqdm +import os +import torchaudio +import torch + + +from utils.mfa_prepare import ( + process_wav_files, + get_wav_files, + filter_wav_files_by_length, +) +from utils.cut_by_vad import cut_segments +from utils.whisper_transcription import asr_main +from utils.util import has_existed + +import subprocess +import random +from collections import defaultdict +from glob import glob +import shutil + + +def librilight_statistics(data_dir): + """Get statistics for librilight dataset""" + distribution2speakers2utts = defaultdict(lambda: defaultdict(list)) + distribution_infos = glob(data_dir + "/*") + for distribution_info in distribution_infos: + distribution = distribution_info.split("/")[-1] + print(distribution) + speaker_infos = glob(distribution_info + "/*") + if len(speaker_infos) == 0: + continue + for speaker_info in speaker_infos: + speaker = speaker_info.split("/")[-1] + utts = glob(speaker_info + "/*.wav") + for utt in utts: + uid = utt.split("/")[-1].split(".")[0] + distribution2speakers2utts[distribution][speaker].append(uid) + return distribution2speakers2utts + + +def get_speakers_from_directory(directory): + return [ + d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d)) + ] + + +def split_dataset_by_speaker(base_dir, train_ratio=0.8, dev_ratio=0.1): + train_dir = os.path.join(base_dir, "train") + dev_dir = os.path.join(base_dir, "dev") + eval_dir = os.path.join(base_dir, "eval") + + # Check if dataset is already split + if has_existed(train_dir) or has_existed(dev_dir) or has_existed(eval_dir): + print("Dataset already split. Calculating speakers...") + train_speakers = get_speakers_from_directory(train_dir) + dev_speakers = get_speakers_from_directory(dev_dir) + eval_speakers = get_speakers_from_directory(eval_dir) + all_speakers = train_speakers + dev_speakers + eval_speakers + unique_speakers = list(set(all_speakers)) + unique_speakers.sort() + return unique_speakers + + # List all directories in the base directory + all_speakers = [ + d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) + ] + random.shuffle(all_speakers) + + # Calculate split sizes + total_speakers = len(all_speakers) + train_size = int(total_speakers * train_ratio) + dev_size = int(total_speakers * dev_ratio) + eval_size = total_speakers - train_size - dev_size + print("Total speakers:", total_speakers) + print("Train speakers:", train_size) + print("Dev speakers:", dev_size) + print("Eval speakers:", eval_size) + + # Split directories + train_speakers = all_speakers[:train_size] + dev_speakers = all_speakers[train_size : train_size + dev_size] + eval_speakers = all_speakers[train_size + dev_size :] + + # Function to move directories + def move_speakers(speakers, target_dir): + for speaker in speakers: + shutil.move( + os.path.join(base_dir, speaker), os.path.join(target_dir, speaker) + ) + + # Move directories + print("Moving directories...") + print("Moving Train speakers...") + move_speakers(train_speakers, train_dir) + print("Moving Dev speakers...") + move_speakers(dev_speakers, dev_dir) + print("Moving Eval speakers...") + move_speakers(eval_speakers, eval_dir) + + unique_speakers = list(set(all_speakers)) + unique_speakers.sort() + return unique_speakers + + +def save_meta_data(save_dir, processed_dir, distribution2speakers2utts, speakers): + """Save metadata for librilight dataset""" + os.makedirs(save_dir, exist_ok=True) + train_output_file = os.path.join(save_dir, "train.json") + valid_output_file = 
os.path.join(save_dir, "dev.json") + test_output_file = os.path.join(save_dir, "eval.json") + singer_dict_file = os.path.join(save_dir, "singers.json") + utt2singer_file = os.path.join(save_dir, "utt2singer") + utt2singer = open(utt2singer_file, "w") + if has_existed(train_output_file): + print("Metadata already exists. Skipping...") + return + + train = [] + test = [] + valid = [] + + train_index_count = 0 + test_index_count = 0 + valid_index_count = 0 + + train_total_duration = 0 + test_total_duration = 0 + valid_total_duration = 0 + + # Save metadata + for distribution, speakers2utts in tqdm(distribution2speakers2utts.items()): + for speaker, utts in tqdm(speakers2utts.items()): + for chosen_uid in utts: + res = { + "Dataset": "librilight", + "Singer": speaker, + "Uid": "{}#{}#{}".format(distribution, speaker, chosen_uid), + } + res["Path"] = "{}/{}/{}.wav".format(distribution, speaker, chosen_uid) + res["Path"] = os.path.join(processed_dir, res["Path"]) + assert os.path.exists(res["Path"]) + + text_file_path = os.path.join( + processed_dir, + distribution, + speaker, + chosen_uid + ".txt", + ) + with open(text_file_path, "r") as f: + lines = f.readlines() + assert len(lines) == 1 + text = lines[0].strip() + res["Text"] = text + + waveform, sample_rate = torchaudio.load(res["Path"]) + duration = waveform.size(-1) / sample_rate + res["Duration"] = duration + + if "train" in distribution: + res["index"] = train_index_count + train_total_duration += duration + train.append(res) + train_index_count += 1 + elif "dev" in distribution: + res["index"] = valid_index_count + valid_total_duration += duration + valid.append(res) + valid_index_count += 1 + elif "eval" in distribution: + res["index"] = test_index_count + test_total_duration += duration + test.append(res) + test_index_count += 1 + utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"])) + print("Done!") + print( + "Utterance count: train = {}, dev = {}, eval = {}".format( + len(train), len(valid), len(test) + ) + ) + print( + "#Train duration= {}, #Dev duration= {}, #Eval duration= {}".format( + train_total_duration / 3600, + valid_total_duration / 3600, + test_total_duration / 3600, + ) + ) + with open(train_output_file, "w") as f: + json.dump(train, f, indent=4, ensure_ascii=False) + with open(test_output_file, "w") as f: + json.dump(test, f, indent=4, ensure_ascii=False) + with open(valid_output_file, "w") as f: + json.dump(valid, f, indent=4, ensure_ascii=False) + utt2singer.close() + singer_lut = {name: i for i, name in enumerate(speakers)} + with open(singer_dict_file, "w") as f: + json.dump(singer_lut, f, indent=4, ensure_ascii=False) + print("Metadata saved to", save_dir) + + +def main(output_path, dataset_path, cfg): + """Preprocess librilight dataset""" + n_cpus = cfg.n_cpus # number of cpus to use for preprocessing + n_gpus = cfg.n_gpus # number of gpus to use for transcription + cut_length = cfg.cut_length # target length of utterance in seconds + max_length = cfg.max_length # max length of utterance in seconds + + # MFA files + mfa_config_path = cfg.mfa_config_path # path to mfa config file + mfa_dict_path = cfg.mfa_dict_path # path to mfa dict file + mfa_model_path = cfg.mfa_model_path # path to mfa model file + + # check if mfa files exist + if ( + not os.path.exists(mfa_dict_path) + or not os.path.exists(mfa_model_path) + or not os.path.exists(mfa_config_path) + ): + raise Exception("MFA files not found.") + + # Whisper model id + model_id = cfg.whisper_model_id # id of whisper model to use for transcription + + 
subsets = [
+        d
+        for d in os.listdir(dataset_path)
+        if (
+            os.path.isdir(os.path.join(dataset_path, d))
+            and d in ["tiny", "small", "medium", "large"]
+        )
+    ]
+    print("Found subsets:", subsets)
+
+    if len(subsets) == 0:
+        print("No subsets found. Exiting...")
+        return
+    # Preprocess each subset
+    for subset in subsets:
+        # Construct paths based on the base path
+        print("Pre-processing Libri-light subset:", subset)
+        raw_dir = f"{dataset_path}/{subset}"
+        save_dir = f"{output_path}/{subset}"
+        processed_dir = f"{dataset_path}/processed/{subset}"
+        os.makedirs(processed_dir, exist_ok=True)
+        os.makedirs(save_dir, exist_ok=True)
+
+        # Step 1: Segmentation
+        print("-" * 10)
+        print("Step 1: Segmentation")
+        print("Cutting audio files...")
+
+        cut_segments(raw_dir, processed_dir, cut_length, n_cpus)
+
+        # Steps 2 & 3: Filter and Preprocess
+        print("-" * 10)
+        print("Steps 2 & 3: Filter and Preprocess")
+        print("Filtering and preprocessing audio files...")
+
+        wav_files = get_wav_files(processed_dir)
+        filtered_wav_files = filter_wav_files_by_length(wav_files, max_length)
+        process_wav_files(filtered_wav_files, processed_dir, n_cpus)
+
+        # Steps 4 & 5: Transcription & Text-preprocess
+        print("-" * 10)
+        print("Steps 4 & 5: Transcription & Text-preprocess")
+        print("Transcribing audio files...")
+
+        n_gpus = min(n_gpus, torch.cuda.device_count())
+        asr_main(processed_dir, n_gpus, model_id)
+
+        # Step 6: MFA Align
+        print("-" * 10)
+        print("Step 6: MFA Align")
+        print("Aligning audio files...")
+
+        command = [
+            "mfa",
+            "align",
+            "-v",
+            "-j",
+            str(n_cpus),
+            "-c",
+            mfa_config_path,
+            processed_dir,
+            mfa_dict_path,
+            mfa_model_path,
+            processed_dir,
+            "--output_format",
+            "long_textgrid",
+            "--clean",
+            "--overwrite",
+        ]
+        subprocess.run(command, text=True)
+
+        # Step 7: train/dev/eval split
+        print("-" * 10)
+        print("Step 7: train/dev/eval split")
+        print("Splitting dataset by speaker...")
+
+        speakers = split_dataset_by_speaker(processed_dir)
+
+        # Step 8: Statistics
+        print("-" * 10)
+        print("Step 8: Statistics")
+        print("Calculating statistics...")
+
+        distribution2speakers2utts = librilight_statistics(processed_dir)
+
+        # Step 9: Save metadata
+        print("-" * 10)
+        print("Step 9: Save metadata")
+        print("Preparing metadata for LibriLight...")
+
+        save_meta_data(save_dir, processed_dir, distribution2speakers2utts, speakers)
+        print("Preprocessing subset", subset, "done!")
+        print("-" * 10)
+
+
+if __name__ == "__main__":
+    dataset_path = "/path/to/dataset/librilight"
+    output_path = "/path/to/output"
+    # `cfg` (the Amphion preprocessing config) is normally supplied by
+    # preprocessors/processor.py; load it before running this file standalone.
+    cfg = None  # placeholder
+    main(output_path, dataset_path, cfg)
diff --git a/preprocessors/processor.py b/preprocessors/processor.py
index ce6638e0..1a1d0362 100644
--- a/preprocessors/processor.py
+++ b/preprocessors/processor.py
@@ -28,6 +28,7 @@
     customsvcdataset,
     vocalist,
     ljspeech_vocoder,
+    librilight,
 )


@@ -90,6 +91,8 @@ def preprocess_dataset(
         cocoeval.main(output_path, dataset_path)
     if dataset == "vocalist":
         vocalist.main(output_path, dataset_path)
+    if dataset == "librilight":
+        librilight.main(output_path, dataset_path, cfg)


 def prepare_align(dataset, dataset_path, cfg, output_path):
diff --git a/utils/cut_by_vad.py b/utils/cut_by_vad.py
new file mode 100644
index 00000000..0d41a4a1
--- /dev/null
+++ b/utils/cut_by_vad.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
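+#
+# Each Libri-Light .flac ships with a JSON sidecar whose "voice_activity"
+# field is a list of [start_sec, end_sec] pairs. cut_sequence() below stitches
+# those voiced spans together and flushes a new 16 kHz wav roughly every
+# `target_len_sec` seconds of accumulated speech. A minimal invocation with
+# placeholder paths (librilight.py calls it with its own config values) is:
+#
+#   cut_segments("/path/to/librilight/small", "/path/to/processed/small",
+#                target_len_sec=10, n_process=16)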
+ +""" This code is modified from https://github.com/facebookresearch/libri-light/blob/main/data_preparation/cut_by_vad.py""" +import pathlib +import soundfile as sf +import numpy as np +import json +import multiprocessing +import tqdm + + +def save(seq, fname, index, extension): + """save audio sequences to file""" + output = np.hstack(seq) + file_name = fname.parent / (fname.stem + f"_{index:04}{extension}") + fname.parent.mkdir(exist_ok=True, parents=True) + sf.write(file_name, output, samplerate=16000) + + +def cut_sequence(path, vad, path_out, target_len_sec, out_extension): + """cut audio sequences based on VAD""" + data, samplerate = sf.read(path) + + assert len(data.shape) == 1 + assert samplerate == 16000 + + to_stitch = [] + length_accumulated = 0.0 + + i = 0 + # Iterate over VAD segments + for start, end in vad: + start_index = int(start * samplerate) + end_index = int(end * samplerate) + slice = data[start_index:end_index] + + # Save slices that exceed the target length or if there's already accumulated audio + if ( + length_accumulated + (end - start) > target_len_sec + and length_accumulated > 0 + ): + save(to_stitch, path_out, i, out_extension) + to_stitch = [] + i += 1 + length_accumulated = 0 + + # Add the current slice to the list to be stitched + to_stitch.append(slice) + length_accumulated += end - start + + # Save any remaining slices + if to_stitch: + save(to_stitch, path_out, i, out_extension) + + +def cut_book(task): + """process each book in the dataset""" + path_book, root_out, target_len_sec, extension = task + + speaker = pathlib.Path(path_book.parent.name) + + for i, meta_file_path in enumerate(path_book.glob("*.json")): + with open(meta_file_path, "r") as f: + meta = json.loads(f.read()) + book_id = meta["book_meta"]["id"] + vad = meta["voice_activity"] + + sound_file = meta_file_path.parent / (meta_file_path.stem + ".flac") + + path_out = root_out / speaker / book_id / (meta_file_path.stem) + cut_sequence(sound_file, vad, path_out, target_len_sec, extension) + + +def cut_segments( + input_dir, output_dir, target_len_sec=30, n_process=32, out_extension=".wav" +): + """Main function to cut segments from audio files""" + + pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) + list_dir = pathlib.Path(input_dir).glob("*/*") + list_dir = [x for x in list_dir if x.is_dir()] + + print(f"{len(list_dir)} directories detected") + print(f"Launching {n_process} processes") + + # Create tasks for multiprocessing + tasks = [ + (path_book, output_dir, target_len_sec, out_extension) for path_book in list_dir + ] + + # Process tasks in parallel using multiprocessing + with multiprocessing.Pool(processes=n_process) as pool: + for _ in tqdm.tqdm(pool.imap_unordered(cut_book, tasks), total=len(tasks)): + pass + + +if __name__ == "__main__": + input_dir = "/path/to/input_dir" + output_dir = "/path/to/output_dir" + target_len_sec = 10 + n_process = 16 + cut_segments(input_dir, output_dir, target_len_sec, n_process) diff --git a/utils/mfa_prepare.py b/utils/mfa_prepare.py new file mode 100644 index 00000000..b79ba862 --- /dev/null +++ b/utils/mfa_prepare.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
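+#
+# The helpers below flatten the segmented corpus into the layout MFA expects:
+# every wav is renamed to "<speaker>/<speaker>_<book>_<file>.wav" and
+# re-encoded to 16 kHz, 16-bit PCM. The per-file conversion is roughly the
+# shell command
+#
+#   ffmpeg -nostdin -hide_banner -loglevel error -nostats \
+#       -i in.wav -acodec pcm_s16le -ar 16000 out.wav
+#
+# Note that filter_wav_files_by_length() deletes files longer than
+# `max_len_sec` and process_wav_files() deletes each source wav after
+# conversion, so both steps modify the corpus in place.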
+ +""" This code is modified from https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/performance.html""" + +import os +import subprocess +from multiprocessing import Pool +from tqdm import tqdm +import torchaudio +from pathlib import Path + + +def remove_empty_dirs(path): + """remove empty directories in a given path""" + # Check if the given path is a directory + if not os.path.isdir(path): + print(f"{path} is not a directory") + return + + # Walk through all directories and subdirectories + for root, dirs, _ in os.walk(path, topdown=False): + for dir in dirs: + dir_path = os.path.join(root, dir) + # Check if the directory is empty + if not os.listdir(dir_path): + os.rmdir(dir_path) # "Removed empty directory + + +def process_single_wav_file(task): + """process a single wav file""" + wav_file, output_dir = task + speaker_id, book_name, filename = Path(wav_file).parts[-3:] + + output_book_dir = Path(output_dir, speaker_id) + output_book_dir.mkdir(parents=True, exist_ok=True) + new_filename = f"{speaker_id}_{book_name}_{filename}" + + new_wav_file = Path(output_book_dir, new_filename) + command = [ + "ffmpeg", + "-nostdin", + "-hide_banner", + "-loglevel", + "error", + "-nostats", + "-i", + wav_file, + "-acodec", + "pcm_s16le", + "-ar", + "16000", + new_wav_file, + ] + subprocess.check_call( + command + ) # Run the command to convert the file to 16kHz and 16-bit PCM + os.remove(wav_file) + + +def process_wav_files(wav_files, output_dir, n_process): + """process wav files in parallel""" + tasks = [(wav_file, output_dir) for wav_file in wav_files] + print(f"Processing {len(tasks)} files") + with Pool(processes=n_process) as pool: + for _ in tqdm( + pool.imap_unordered(process_single_wav_file, tasks), total=len(tasks) + ): + pass + print("Removing empty directories...") + remove_empty_dirs(output_dir) + print("Done!") + + +def get_wav_files(dataset_path): + """get all wav files in the dataset""" + wav_files = [] + for speaker_id in os.listdir(dataset_path): + speaker_dir = os.path.join(dataset_path, speaker_id) + if not os.path.isdir(speaker_dir): + continue + for book_name in os.listdir(speaker_dir): + book_dir = os.path.join(speaker_dir, book_name) + if not os.path.isdir(book_dir): + continue + for file in os.listdir(book_dir): + if file.endswith(".wav"): + wav_files.append(os.path.join(book_dir, file)) + print("Found {} wav files".format(len(wav_files))) + return wav_files + + +def filter_wav_files_by_length(wav_files, max_len_sec=15): + """filter wav files by length""" + print("original wav files: {}".format(len(wav_files))) + filtered_wav_files = [] + for audio_file in wav_files: + metadata = torchaudio.info(str(audio_file)) + audio_length = metadata.num_frames / metadata.sample_rate + if audio_length <= max_len_sec: + filtered_wav_files.append(audio_file) + else: + os.remove(audio_file) + print("filtered wav files: {}".format(len(filtered_wav_files))) + return filtered_wav_files + + +if __name__ == "__main__": + dataset_path = "/path/to/output/directory" + n_process = 16 + max_len_sec = 15 + wav_files = get_wav_files(dataset_path) + filtered_wav_files = filter_wav_files_by_length(wav_files, max_len_sec) + process_wav_files(filtered_wav_files, dataset_path, n_process) diff --git a/utils/whisper_transcription.py b/utils/whisper_transcription.py new file mode 100644 index 00000000..98126987 --- /dev/null +++ b/utils/whisper_transcription.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import pathlib +import string +import time +from multiprocessing import Pool, Value, Lock +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor +import torch +import whisper + +processed_files_count = Value("i", 0) # count of processed files +lock = Lock() # lock for the count + + +def preprocess_text(text): + """Preprocess text after ASR""" + return text.lower().translate(str.maketrans("", "", string.punctuation)) + + +def transcribe_audio(model, processor, audio_file, device): + """Transcribe audio file""" + audio = whisper.load_audio(audio_file) # load from path + audio = whisper.pad_or_trim(audio) # default 30 seconds + inputs = whisper.log_mel_spectrogram(audio).to( + device=device + ) # convert to spectrogram + inputs = inputs.unsqueeze(0).type(torch.cuda.HalfTensor) # add batch dimension + + outputs = model.generate( + inputs=inputs, max_new_tokens=128 + ) # generate transcription + transcription = processor.batch_decode(outputs, skip_special_tokens=True)[ + 0 + ] # decode + transcription_processed = preprocess_text(transcription) # preprocess + return transcription_processed + + +def write_transcription(audio_file, transcription): + """Write transcription to txt file""" + txt_file = audio_file.with_suffix(".txt") + with open(txt_file, "w") as file: + file.write(transcription) + + +def init_whisper(model_id, device): + """Initialize whisper model and processor""" + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + print(f"Loading model {model_id}") # model_id = "distil-whisper/distil-large-v2" + distil_model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False + ) + distil_model = distil_model.to(device) + processor = AutoProcessor.from_pretrained(model_id) + return distil_model, processor + + +def asr_wav_files(file_list, gpu_id, total_files, model_id): + """Transcribe wav files in a list""" + device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu" + whisper_model, processor = init_whisper(model_id, device) + print(f"Processing on {device} starts") + start_time = time.time() + for audio_file in file_list: + try: + transcription = transcribe_audio( + whisper_model, processor, audio_file, device + ) + write_transcription(audio_file, transcription) + with lock: + processed_files_count.value += 1 + if processed_files_count.value % 5 == 0: + current_time = time.time() + avg_time_per_file = (current_time - start_time) / ( + processed_files_count.value + ) + remaining_files = total_files - processed_files_count.value + estimated_time_remaining = avg_time_per_file * remaining_files + remaining_time_formatted = time.strftime( + "%H:%M:%S", time.gmtime(estimated_time_remaining) + ) + print( + f"Processed {processed_files_count.value}/{total_files} files, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, Estimated time remaining: {remaining_time_formatted}" + ) + except Exception as e: + print(f"Error processing file {audio_file}: {e}") + + +def asr_main(input_dir, num_gpus, model_id): + """Transcribe wav files in a directory""" + num_processes = min(num_gpus, os.cpu_count()) + print(f"Using {num_processes} GPUs for transcription") + wav_files = list(pathlib.Path(input_dir).rglob("*.wav")) + total_files = len(wav_files) + print(f"Found {total_files} wav files in {input_dir}") + files_per_process = len(wav_files) // 
num_processes
+    print(f"Processing about {files_per_process} files per process")
+    with Pool(num_processes) as p:
+        p.starmap(
+            asr_wav_files,
+            [
+                (
+                    # Round-robin split so the files left over by the integer
+                    # division above are not silently dropped.
+                    wav_files[i::num_processes],
+                    i % num_gpus,
+                    total_files,
+                    model_id,
+                )
+                for i in range(num_processes)
+            ],
+        )
+    print("Done!")
+
+
+if __name__ == "__main__":
+    input_dir = "/path/to/output/directory"
+    num_gpus = 2
+    model_id = "distil-whisper/distil-large-v2"
+    asr_main(input_dir, num_gpus, model_id)
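+
+    # Usage notes: librilight.py caps `num_gpus` at torch.cuda.device_count()
+    # before calling asr_main(), and each worker process pins itself to
+    # "cuda:<i % num_gpus>". Since transcribe_audio() casts its inputs to
+    # torch.cuda.HalfTensor, a CUDA device is effectively required. A minimal
+    # single-file sanity check (hypothetical path) would look like:
+    #
+    #   model, processor = init_whisper(model_id, "cuda:0")
+    #   print(transcribe_audio(model, processor, "/path/to/sample.wav", "cuda:0"))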