diff --git a/models/tts/naturalspeech2/ns2_trainer.py b/models/tts/naturalspeech2/ns2_trainer.py index a523bf35..63c4353e 100644 --- a/models/tts/naturalspeech2/ns2_trainer.py +++ b/models/tts/naturalspeech2/ns2_trainer.py @@ -35,7 +35,6 @@ class NS2Trainer(TTSTrainer): def __init__(self, args, cfg): - self.args = args self.cfg = cfg @@ -355,7 +354,7 @@ def _build_dataloader(self): def _build_optimizer(self): optimizer = torch.optim.AdamW( filter(lambda p: p.requires_grad, self.model.parameters()), - **self.cfg.train.adam + **self.cfg.train.adam, ) return optimizer @@ -796,4 +795,4 @@ def train_loop(self): ), ) ) - self.accelerator.end_training() \ No newline at end of file + self.accelerator.end_training()