diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py
index e10df8c0..cd9439a4 100644
--- a/luxonis_train/models/luxonis_lightning.py
+++ b/luxonis_train/models/luxonis_lightning.py
@@ -813,8 +813,32 @@ def configure_optimizers(
         }
         optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)
 
-        scheduler_params = cfg_scheduler.params | {"optimizer": optimizer}
-        scheduler = SCHEDULERS.get(cfg_scheduler.name)(**scheduler_params)
+        def get_scheduler(scheduler_cfg, optimizer):
+            scheduler_class = SCHEDULERS.get(
+                scheduler_cfg["name"]
+            )  # For dictionary access
+            scheduler_params = scheduler_cfg["params"] | {
+                "optimizer": optimizer
+            }  # Dictionary access for params
+            return scheduler_class(**scheduler_params)
+
+        if cfg_scheduler.name == "SequentialLR":
+            schedulers_list = [
+                get_scheduler(scheduler_cfg, optimizer)
+                for scheduler_cfg in cfg_scheduler.params["schedulers"]
+            ]
+
+            scheduler = torch.optim.lr_scheduler.SequentialLR(
+                optimizer,
+                schedulers=schedulers_list,
+                milestones=cfg_scheduler.params["milestones"],
+            )
+        else:
+            scheduler_class = SCHEDULERS.get(
+                cfg_scheduler.name
+            )  # Access as attribute for single scheduler
+            scheduler_params = cfg_scheduler.params | {"optimizer": optimizer}
+            scheduler = scheduler_class(**scheduler_params)
 
         return [optimizer], [scheduler]
 
diff --git a/tests/integration/test_scheduler.py b/tests/integration/test_scheduler.py
new file mode 100644
index 00000000..ba3f4364
--- /dev/null
+++ b/tests/integration/test_scheduler.py
@@ -0,0 +1,69 @@
+import multiprocessing as mp
+
+import pytest
+from luxonis_ml.data import LuxonisDataset
+
+from luxonis_train.core import LuxonisModel
+
+
+def create_model_config():
+    return {
+        "trainer": {
+            "n_sanity_val_steps": 0,
+            "preprocessing": {"train_image_size": [32, 32]},
+            "epochs": 2,
+            "batch_size": 4,
+            "num_workers": mp.cpu_count(),
+        },
+        "loader": {
+            "name": "LuxonisLoaderTorch",
+            "train_view": "train",
+            "params": {"dataset_name": "coco_test"},
+        },
+        "model": {
+            "name": "detection_light_model",
+            "predefined_model": {
+                "name": "DetectionModel",
+                "params": {
+                    "variant": "light",
+                },
+            },
+        },
+    }
+
+
+def sequential_scheduler():
+    return {
+        "name": "SequentialLR",
+        "params": {
+            "schedulers": [
+                {
+                    "name": "LinearLR",
+                    "params": {"start_factor": 0.1, "total_iters": 10},
+                },
+                {
+                    "name": "CosineAnnealingLR",
+                    "params": {"T_max": 1, "eta_min": 0.01},
+                },
+            ],
+            "milestones": [1],
+        },
+    }
+
+
+def cosine_annealing_scheduler():
+    return {
+        "name": "CosineAnnealingLR",
+        "params": {"T_max": 2, "eta_min": 0.001},
+    }
+
+
+@pytest.mark.parametrize(
+    "scheduler_config", [sequential_scheduler(), cosine_annealing_scheduler()]
+)
+def test_scheduler(coco_dataset: LuxonisDataset, scheduler_config):
+    config = create_model_config()
+    config["trainer"]["scheduler"] = scheduler_config
+    opts = {"loader.params.dataset_name": coco_dataset.dataset_name}
+    model = LuxonisModel(config, opts)
+    model.train()
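
Note (not part of the patch): the new SequentialLR branch assembles the same object that the plain-PyTorch snippet below builds by hand. Each child scheduler is constructed against the shared optimizer first, and SequentialLR then switches between them at the configured milestones. This is a minimal sketch for reference only; the toy model and the SGD learning rate are illustrative assumptions, while the scheduler parameters follow the sequential_scheduler() test config above.

import torch
from torch import nn

# Illustrative model/optimizer; in LuxonisTrain these come from the config.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Child schedulers are built first, each bound to the same optimizer,
# just as get_scheduler() does via the SCHEDULERS registry.
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=10
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=1, eta_min=0.01
)

# SequentialLR delegates to `warmup` until the milestone (epoch 1)
# and to the cosine schedule afterwards.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[1]
)

# Stepping the combined scheduler advances whichever child is active.
for _ in range(3):
    optimizer.step()
    scheduler.step()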