From 2eae1f07b09124af5632572e7886811e98294f54 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20Li=20=28=E6=9D=8E=E5=87=AF=29?= <lk21@mails.tsinghua.edu.cn>
Date: Wed, 23 Oct 2024 20:17:52 +0800
Subject: [PATCH] Update training code

---
 README.md                  |  22 +++
 enhancement/audio_train.py | 121 +++++++++++++++++++++++++++++++++++
 separation/audio_train.py  | 127 +++++++++++++++++++++++++++++++++++++
 3 files changed, 270 insertions(+)
 create mode 100644 enhancement/audio_train.py
 create mode 100644 separation/audio_train.py

diff --git a/README.md b/README.md
index d983f1d..2c21da6 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,8 @@ We introduce SonicSim, a synthetic toolkit designed to generate highly customiza
 
 ## 🔥 News
 
+- [2024-10-23] We release the training code for speech separation and enhancement models on the [SonicSet dataset](#sonicset-dataset).
+
 - [2024-10-03] We release the paper on [arxiv](https://arxiv.org/abs/2410.01481)
 
 - [2024-10-01] We release the [Real-world speech separation dataset](#real-world-data), which aims to evaluate the performance of speech separation models in real-world scenarios.
@@ -90,6 +92,26 @@ conda env create -f SonicSim/torch-2.0.yml
 conda activate SonicSim
 ```
 
+### Training Speech Separation and Enhancement Models
+
+#### Training Speech Separation Models
+
+Navigate to the `separation` directory and run the training script with the specified configuration file:
+
+```bash
+cd separation
+python audio_train.py --conf_dir=../sep-checkpoints/TFGNet-Noise/config.yaml
+```
+
+#### Training Speech Enhancement Models
+
+Navigate to the `enhancement` directory and run the training script with the specified configuration file:
+
+```bash
+cd enhancement
+python audio_train.py --conf_dir=../enh-checkpoints/TaylorSENet-Noise/config.yaml
+```
+
 ### Download Checkpoints
 
 Please check the contents of README.md in the [sep-checkpoints](https://github.com/JusperLee/SonicSim/tree/main/sep-checkpoints) and [enh-checkpoints](https://github.com/JusperLee/SonicSim/tree/main/enh-checkpoints) folders, download the appropriate pre-trained models in [Release](https://github.com/JusperLee/SonicSim/releases/tag/v1.0) and unzip them into the appropriate folders.
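Both `audio_train.py` scripts below are driven entirely by the YAML file passed via `--conf_dir`: the script loads it with `OmegaConf.load`, snapshots it to `exp.dir/exp.name/config.yaml`, and builds every component (datamodule, model, optimizer, loss, metrics, system, callbacks, logger, trainer) from its `_target_` entry via `hydra.utils.instantiate`. For orientation, here is a minimal sketch of that structure. The `_target_` paths marked hypothetical are placeholders, not the repository's actual class names, and `seed`, `scheduler`, and `early_stopping` are optional keys the script probes with `cfg.get`:

```python
# Sketch only: mirrors the keys audio_train.py reads, with placeholder targets.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "exp": {"dir": "./exps", "name": "TFGNet-Noise"},
    "seed": 42,  # optional
    "datas": {"_target_": "look2hear.datas.SonicSetDataModule"},  # hypothetical
    "model": {"_target_": "look2hear.models.TFGNet"},             # hypothetical
    "optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-3},
    "loss": {"_target_": "look2hear.losses.PITLossWrapper"},      # hypothetical
    "metrics": {"_target_": "look2hear.losses.SISNRMetric"},      # hypothetical
    "system": {"_target_": "look2hear.system.AudioLitModule"},    # hypothetical
    "checkpoint": {
        "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
        "monitor": "val_loss",  # hypothetical monitor key
        "save_top_k": 5,
    },
    "logger": {"_target_": "pytorch_lightning.loggers.WandbLogger"},
    "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 200},
})
OmegaConf.save(cfg, "config.yaml")  # the file later passed as --conf_dir
```

Because the whole experiment lives in one file, the `config.yaml` snapshot the script writes is sufficient to reproduce a run.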
diff --git a/enhancement/audio_train.py b/enhancement/audio_train.py
new file mode 100644
index 0000000..1ddb1b3
--- /dev/null
+++ b/enhancement/audio_train.py
@@ -0,0 +1,121 @@
+import json
+from typing import Any, Dict, List, Optional, Tuple
+import os
+from omegaconf import OmegaConf
+import argparse
+import pytorch_lightning as pl
+import torch
+torch.set_float32_matmul_precision("highest")
+import hydra
+from pytorch_lightning.strategies.ddp import DDPStrategy
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
+# from pytorch_lightning.loggers import Logger
+from omegaconf import DictConfig
+import look2hear.system
+import look2hear.datas
+import look2hear.losses
+from look2hear.utils import RankedLogger, instantiate, print_only
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def train(cfg: DictConfig) -> None:
+    if cfg.get("seed"):
+        pl.seed_everything(cfg.seed, workers=True)
+
+    # instantiate datamodule
+    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)
+
+    # instantiate model
+    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
+    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
+
+    # instantiate optimizer
+    print_only(f"Instantiating optimizer <{cfg.optimizer._target_}>")
+    optimizer: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
+    # optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)
+
+    # instantiate scheduler (optional)
+    if cfg.get("scheduler"):
+        print_only(f"Instantiating scheduler <{cfg.scheduler._target_}>")
+        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
+    else:
+        scheduler = None
+
+    # instantiate loss
+    print_only(f"Instantiating loss <{cfg.loss._target_}>")
+    loss: torch.nn.Module = hydra.utils.instantiate(cfg.loss)
+
+    # instantiate metrics
+    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
+    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)
+    # instantiate system
+    print_only(f"Instantiating system <{cfg.system._target_}>")
+    system: LightningModule = hydra.utils.instantiate(
+        cfg.system,
+        model=model,
+        loss_func=loss,
+        metrics=metrics,
+        optimizer=optimizer,
+        scheduler=scheduler
+    )
+
+    # instantiate callbacks
+    callbacks: List[Callback] = []
+    if cfg.get("early_stopping"):
+        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
+        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
+    if cfg.get("checkpoint"):
+        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
+        checkpoint: pl.callbacks.ModelCheckpoint = hydra.utils.instantiate(cfg.checkpoint)
+        callbacks.append(checkpoint)
+
+    # instantiate logger
+    print_only(f"Instantiating logger <{cfg.logger._target_}>")
+    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
+    logger = hydra.utils.instantiate(cfg.logger)
+
+    # instantiate trainer
+    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
+        strategy=DDPStrategy(find_unused_parameters=True),
+    )
+
+    trainer.fit(system, datamodule=datamodule)
+    print_only("Training finished!")
+    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}  # assumes cfg.checkpoint is set
+    with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
+        json.dump(best_k, f, indent=0)
+
+    state_dict = torch.load(checkpoint.best_model_path)
+    system.load_state_dict(state_dict=state_dict["state_dict"])
+    system.cpu()
+
+    to_save = system.audio_model.serialize()
+    torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))
+    import wandb
+    if wandb.run:
+        print_only("Closing wandb!")
+        wandb.finish()
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--conf_dir",
+        default="local/conf.yml",
+        help="Full path to the training configuration file",
+    )
+
+    args = parser.parse_args()
+    cfg = OmegaConf.load(args.conf_dir)
+
+    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
+    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))
+
+    train(cfg)
+
\ No newline at end of file
diff --git a/separation/audio_train.py b/separation/audio_train.py
new file mode 100644
index 0000000..4de179f
--- /dev/null
+++ b/separation/audio_train.py
@@ -0,0 +1,127 @@
+###
+# Author: Kai Li
+# Date: 2024-01-22 01:16:22
+# Email: lk21@mails.tsinghua.edu.cn
+# LastEditTime: 2024-01-24 00:05:10
+###
+import json
+from typing import Any, Dict, List, Optional, Tuple
+import os
+from omegaconf import OmegaConf
+import argparse
+import pytorch_lightning as pl
+import torch
+torch.set_float32_matmul_precision("highest")
+import hydra
+from pytorch_lightning.strategies.ddp import DDPStrategy
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
+# from pytorch_lightning.loggers import Logger
+from omegaconf import DictConfig
+import look2hear.system
+import look2hear.datas
+import look2hear.losses
+from look2hear.utils import RankedLogger, instantiate, print_only
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def train(cfg: DictConfig) -> None:
+    if cfg.get("seed"):
+        pl.seed_everything(cfg.seed, workers=True)
+
+    # instantiate datamodule
+    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)
+
+    # instantiate model
+    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
+    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
+
+    # instantiate optimizer
+    print_only(f"Instantiating optimizer <{cfg.optimizer._target_}>")
+    optimizer: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
+    # optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)
+
+    # instantiate scheduler (optional)
+    if cfg.get("scheduler"):
+        print_only(f"Instantiating scheduler <{cfg.scheduler._target_}>")
+        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
+    else:
+        scheduler = None
+
+    # instantiate loss
+    print_only(f"Instantiating loss <{cfg.loss._target_}>")
+    loss: torch.nn.Module = hydra.utils.instantiate(cfg.loss)
+
+    # instantiate metrics
+    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
+    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)
+    # instantiate system
+    print_only(f"Instantiating system <{cfg.system._target_}>")
+    system: LightningModule = hydra.utils.instantiate(
+        cfg.system,
+        model=model,
+        loss_func=loss,
+        metrics=metrics,
+        optimizer=optimizer,
+        scheduler=scheduler
+    )
+
+    # instantiate callbacks
+    callbacks: List[Callback] = []
+    if cfg.get("early_stopping"):
+        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
+        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
+    if cfg.get("checkpoint"):
+        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
+        checkpoint: pl.callbacks.ModelCheckpoint = hydra.utils.instantiate(cfg.checkpoint)
+        callbacks.append(checkpoint)
+
+    # instantiate logger
+    print_only(f"Instantiating logger <{cfg.logger._target_}>")
+    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
+    logger = hydra.utils.instantiate(cfg.logger)
+
+    # instantiate trainer
+    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
+        strategy=DDPStrategy(find_unused_parameters=True),
+    )
+
+    trainer.fit(system, datamodule=datamodule)
+    print_only("Training finished!")
+    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}  # assumes cfg.checkpoint is set
+    with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
+        json.dump(best_k, f, indent=0)
+
+    state_dict = torch.load(checkpoint.best_model_path)
+    system.load_state_dict(state_dict=state_dict["state_dict"])
+    system.cpu()
+
+    to_save = system.audio_model.serialize()
+    torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))
+    import wandb
+    if wandb.run:
+        print_only("Closing wandb!")
+        wandb.finish()
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--conf_dir",
+        default="local/conf.yml",
+        help="Full path to the training configuration file",
+    )
+
+    args = parser.parse_args()
+    cfg = OmegaConf.load(args.conf_dir)
+
+    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
+    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))
+
+    train(cfg)
+
\ No newline at end of file