Skip to content

Commit

Permalink
fix cosine_schedule_with_warmup (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
HeCheng0625 authored Dec 23, 2023
1 parent 96e028c commit 001ebab
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions config/valle.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"optimizer": "AdamW",
"scheduler": "cosine",
"warmup_steps": 16000, // number of steps that affects how rapidly the learning rate decreases
"total_training_steps": 800000,
"base_lr": 1e-4, // base learning rate."
"valid_interval": 1000,
"log_epoch_step": 1000,
Expand Down
2 changes: 1 addition & 1 deletion egs/tts/VALLE/exp_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"use_acoustic_token": true,
"processed_dir": "Amphion/data/",
"sample_rate": 24000, // "Audio sampling rate."
"codec_hop_size": "320", // "Audio codec hop size."
"codec_hop_size": 320, // "Audio codec hop size."
"valid_file": "test.json",
},
"model": {
Expand Down
11 changes: 8 additions & 3 deletions models/tts/valle/valle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from models.tts.base import TTSTrainer
from models.tts.valle.valle import VALLE
import diffusers


class VALLETrainer(TTSTrainer):
Expand Down Expand Up @@ -108,10 +109,14 @@ def _build_scheduler(self):
warmup_steps=self.cfg.train.warmup_steps,
)
elif self.cfg.train.scheduler.lower() == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.cfg.train.warmup_steps,
from diffusers.optimization import get_cosine_schedule_with_warmup

scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
eta_min=self.cfg.train.base_lr,
num_warmup_steps=self.cfg.train.warmup_steps
* self.accelerator.num_processes,
num_training_steps=self.cfg.train.total_training_steps
* self.accelerator.num_processes,
)
else:
raise NotImplementedError(f"{self.cfg.train.scheduler}")
Expand Down

0 comments on commit 001ebab

Please sign in to comment.