diff --git a/bins/tts/train.py b/bins/tts/train.py index 7b4e2ec9..785a0d36 100644 --- a/bins/tts/train.py +++ b/bins/tts/train.py @@ -97,9 +97,9 @@ def main(): # VALLETrainer.add_arguments(parser) args = parser.parse_args() cfg = load_config(args.config) - + # Data Augmentation - if hasattr(cfg, 'preprocess'): + if hasattr(cfg, "preprocess"): if hasattr(cfg.preprocess, "data_augment"): if ( type(cfg.preprocess.data_augment) == list @@ -108,14 +108,26 @@ def main(): new_datasets_list = [] for dataset in cfg.preprocess.data_augment: new_datasets = [ - f"{dataset}_pitch_shift" if cfg.preprocess.use_pitch_shift else None, + ( + f"{dataset}_pitch_shift" + if cfg.preprocess.use_pitch_shift + else None + ), ( f"{dataset}_formant_shift" if cfg.preprocess.use_formant_shift else None ), - f"{dataset}_equalizer" if cfg.preprocess.use_equalizer else None, - f"{dataset}_time_stretch" if cfg.preprocess.use_time_stretch else None, + ( + f"{dataset}_equalizer" + if cfg.preprocess.use_equalizer + else None + ), + ( + f"{dataset}_time_stretch" + if cfg.preprocess.use_time_stretch + else None + ), ] new_datasets_list.extend(filter(None, new_datasets)) cfg.dataset.extend(new_datasets_list)