Skip to content

Commit

Permalink
Add code for parallel preprocessing (run acoustic feature extraction in a process pool instead of serially)
Browse files: browse the repository at this point in the history
  • Loading branch information
mboyanov committed Sep 5, 2024
1 parent 0af91f1 commit cf03e0b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
7 changes: 2 additions & 5 deletions bins/svc/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@ def extract_acoustic_features(dataset, output_path, cfg, n_workers=1):
with open(dataset_file, "r") as f:
metadata.extend(json.load(f))

# acoustic_extractor.extract_utt_acoustic_features_parallel(
# metadata, dataset_output, cfg, n_workers=n_workers
# )
acoustic_extractor.extract_utt_acoustic_features_serial(
metadata, dataset_output, cfg
acoustic_extractor.extract_utt_acoustic_features_parallel(
metadata, dataset_output, cfg, n_workers=n_workers
)


Expand Down
33 changes: 23 additions & 10 deletions processors/acoustic_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from functools import partial

import torch
import numpy as np

Expand All @@ -23,6 +24,7 @@
extract_linear_features,
extract_mel_features_tts,
)
from concurrent.futures import as_completed, ProcessPoolExecutor

ZERO = 1e-12

Expand All @@ -39,15 +41,26 @@ def extract_utt_acoustic_features_parallel(metadata, dataset_output, cfg, n_work
Returns:
list: acoustic features
"""
for utt in tqdm(metadata):
if cfg.task_type == "tts":
extract_utt_acoustic_features_tts(dataset_output, cfg, utt)
if cfg.task_type == "svc":
extract_utt_acoustic_features_svc(dataset_output, cfg, utt)
if cfg.task_type == "vocoder":
extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt)
if cfg.task_type == "tta":
extract_utt_acoustic_features_tta(dataset_output, cfg, utt)
extractor = None
if cfg.task_type == "tts":
extractor = partial(extract_utt_acoustic_features_tts, dataset_output, cfg)
if cfg.task_type == "svc":
extractor = partial(extract_utt_acoustic_features_svc, dataset_output, cfg)
if cfg.task_type == "vocoder":
extractor = partial(extract_utt_acoustic_features_vocoder, dataset_output, cfg)
if cfg.task_type == "tta":
extractor = partial(extract_utt_acoustic_features_tta, dataset_output, cfg)

with ProcessPoolExecutor(max_workers=n_workers) as pool:
future_to_utt = {
pool.submit(extractor, utt): utt for utt in metadata
}
for future in tqdm(as_completed(future_to_utt), total=len(future_to_utt)):
utt = future_to_utt[future]
try:
future.result()
except Exception as exc:
print('%r generated an exception: %s' % (utt, exc))


def avg_phone_feature(feature, duration, interpolation=False):
Expand Down

0 comments on commit cf03e0b

Please sign in to comment.