SequentialLR Scheduler (#111)
JSabadin authored Oct 18, 2024
1 parent af41af5 commit 9c0cb4c
Showing 2 changed files with 95 additions and 2 deletions.
28 changes: 26 additions & 2 deletions luxonis_train/models/luxonis_lightning.py
@@ -851,8 +851,32 @@ def configure_optimizers(
         }
         optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)
 
-        scheduler_params = cfg_scheduler.params | {"optimizer": optimizer}
-        scheduler = SCHEDULERS.get(cfg_scheduler.name)(**scheduler_params)
+        def get_scheduler(scheduler_cfg, optimizer):
+            scheduler_class = SCHEDULERS.get(
+                scheduler_cfg["name"]
+            )  # For dictionary access
+            scheduler_params = scheduler_cfg["params"] | {
+                "optimizer": optimizer
+            }  # Dictionary access for params
+            return scheduler_class(**scheduler_params)
+
+        if cfg_scheduler.name == "SequentialLR":
+            schedulers_list = [
+                get_scheduler(scheduler_cfg, optimizer)
+                for scheduler_cfg in cfg_scheduler.params["schedulers"]
+            ]
+
+            scheduler = torch.optim.lr_scheduler.SequentialLR(
+                optimizer,
+                schedulers=schedulers_list,
+                milestones=cfg_scheduler.params["milestones"],
+            )
+        else:
+            scheduler_class = SCHEDULERS.get(
+                cfg_scheduler.name
+            )  # Access as attribute for single scheduler
+            scheduler_params = cfg_scheduler.params | {"optimizer": optimizer}
+            scheduler = scheduler_class(**scheduler_params)
 
         return [optimizer], [scheduler]

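For context, the snippet below is a minimal standalone sketch (not part of this commit) of the composition the new `SequentialLR` branch builds from `cfg_scheduler.params`: a `LinearLR` warm-up that hands control to `CosineAnnealingLR` at the configured milestone. The model, optimizer, and hyperparameter values are illustrative placeholders.

import torch.optim as optim
from torch import nn

# Illustrative model and optimizer; any parameter group works the same way.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Two sub-schedulers, analogous to the entries of cfg_scheduler.params["schedulers"].
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=10)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90, eta_min=0.01)

# SequentialLR runs `warmup` until the milestone epoch, then switches to `cosine`.
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[10]
)

for epoch in range(100):
    optimizer.step()   # the epoch's training steps would go here
    scheduler.step()   # advances whichever sub-scheduler is currently active

The test file below exercises the same shape of configuration end to end through LuxonisModel.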
69 changes: 69 additions & 0 deletions tests/integration/test_scheduler.py
@@ -0,0 +1,69 @@
import multiprocessing as mp

import pytest
from luxonis_ml.data import LuxonisDataset

from luxonis_train.core import LuxonisModel


def create_model_config():
    return {
        "trainer": {
            "n_sanity_val_steps": 0,
            "preprocessing": {"train_image_size": [32, 32]},
            "epochs": 2,
            "batch_size": 4,
            "num_workers": mp.cpu_count(),
        },
        "loader": {
            "name": "LuxonisLoaderTorch",
            "train_view": "train",
            "params": {"dataset_name": "coco_test"},
        },
        "model": {
            "name": "detection_light_model",
            "predefined_model": {
                "name": "DetectionModel",
                "params": {
                    "variant": "light",
                },
            },
        },
    }


def sequential_scheduler():
    return {
        "name": "SequentialLR",
        "params": {
            "schedulers": [
                {
                    "name": "LinearLR",
                    "params": {"start_factor": 0.1, "total_iters": 10},
                },
                {
                    "name": "CosineAnnealingLR",
                    "params": {"T_max": 1, "eta_min": 0.01},
                },
            ],
            "milestones": [1],
        },
    }


def cosine_annealing_scheduler():
    return {
        "name": "CosineAnnealingLR",
        "params": {"T_max": 2, "eta_min": 0.001},
    }


@pytest.mark.parametrize(
    "scheduler_config", [sequential_scheduler(), cosine_annealing_scheduler()]
)
def test_scheduler(coco_dataset: LuxonisDataset, scheduler_config):
    config = create_model_config()
    config["trainer"]["scheduler"] = scheduler_config
    opts = {"loader.params.dataset_name": coco_dataset.dataset_name}
    model = LuxonisModel(config, opts)
    model.train()
