diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md
index 64fbdf4f..f630f495 100644
--- a/luxonis_train/callbacks/README.md
+++ b/luxonis_train/callbacks/README.md
@@ -65,3 +65,16 @@ Callback to perform a test run at the end of the training.
 ## `UploadCheckpoint`
 
 Callback that uploads currently the best checkpoint (based on validation loss) to the tracker location - where all other logs are stored.
+
+## `EMACallback`
+
+Callback that updates the stored parameters using a moving average.
+
+**Parameters:**
+
+| Key                 | Type    | Default value | Description                                                                                      |
+| ------------------- | ------- | ------------- | ------------------------------------------------------------------------------------------------ |
+| `decay`             | `float` | `0.5`         | Decay factor for the moving average.                                                             |
+| `use_dynamic_decay` | `bool`  | `True`        | Whether to use dynamic decay.                                                                     |
+| `decay_tau`         | `float` | `2000`        | Decay tau for dynamic decay.                                                                      |
+| `device`            | `str`   | `None`        | Device to use for the moving average. If `None`, the device is inferred from the model's device   |
diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py
index 95f860a1..a751e322 100644
--- a/luxonis_train/callbacks/__init__.py
+++ b/luxonis_train/callbacks/__init__.py
@@ -13,6 +13,7 @@ from luxonis_train.utils.registry import CALLBACKS
+
 from .archive_on_train_end import ArchiveOnTrainEnd
+from .ema import EMACallback
 from .export_on_train_end import ExportOnTrainEnd
 from .gpu_stats_monitor import GPUStatsMonitor
 from .luxonis_progress_bar import (
@@ -34,6 +35,7 @@ CALLBACKS.register_module(module=StochasticWeightAveraging)
 CALLBACKS.register_module(module=Timer)
 CALLBACKS.register_module(module=ModelPruning)
+CALLBACKS.register_module(module=EMACallback)
+
+
 __all__ = [
@@ -47,4 +49,5 @@
     "TestOnTrainEnd",
     "UploadCheckpoint",
     "GPUStatsMonitor",
+    "EMACallback",
 ]
diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py
new file mode 100644
index 00000000..20c01c04
--- /dev/null
+++ b/luxonis_train/callbacks/ema.py
@@ -0,0 +1,249 @@
+import logging
+import math
+from copy import deepcopy
+from typing import Any
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class ModelEma(nn.Module):
+    """Model Exponential Moving Average.
+
+    Keeps a moving average of everything in the model.state_dict
+    (parameters and buffers).
+    """
+
+    def __init__(
+        self,
+        model: pl.LightningModule,
+        decay: float = 0.9999,
+        use_dynamic_decay: bool = True,
+        decay_tau: float = 2000,
+        device: str | None = None,
+    ):
+        """Constructs `ModelEma`.
+
+        @type model: L{pl.LightningModule}
+        @param model: Pytorch Lightning module.
+        @type decay: float
+        @param decay: Decay rate for the moving average.
+        @type use_dynamic_decay: bool
+        @param use_dynamic_decay: Use dynamic decay rate.
+        @type decay_tau: float
+        @param decay_tau: Decay tau for the moving average.
+        @type device: str | None
+        @param device: Device to perform EMA on.
+ """ + super(ModelEma, self).__init__() + model.eval() + self.state_dict_ema = deepcopy(model.state_dict()) + model.train() + + for p in self.state_dict_ema.values(): + p.requires_grad = False + self.updates = 0 + self.decay = decay + self.use_dynamic_decay = use_dynamic_decay + self.decay_tau = decay_tau + self.device = device + if self.device is not None: + self.state_dict_ema = { + k: v.to(device=device) for k, v in self.state_dict_ema.items() + } + + def update(self, model: pl.LightningModule) -> None: + """Update the stored parameters using a moving average. + + Source: U{} + + @license: U{Apache License 2.0} + + @type model: L{pl.LightningModule} + @param model: Pytorch Lightning module. + """ + with torch.no_grad(): + self.updates += 1 + + if self.use_dynamic_decay: + decay = self.decay * ( + 1 - math.exp(-self.updates / self.decay_tau) + ) + else: + decay = self.decay + + ema_lerp_values = [] + model_lerp_values = [] + for ema_v, model_v in zip( + self.state_dict_ema.values(), model.state_dict().values() + ): + if ema_v.is_floating_point(): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_lerp_values.append(ema_v) + model_lerp_values.append(model_v) + else: + ema_v.copy_(model_v) + + if hasattr(torch, "_foreach_lerp_"): + torch._foreach_lerp_( + ema_lerp_values, model_lerp_values, weight=1.0 - decay + ) + else: + torch._foreach_mul_(ema_lerp_values, scalar=decay) + torch._foreach_add_( + ema_lerp_values, model_lerp_values, alpha=1.0 - decay + ) + + +class EMACallback(Callback): + """Callback that updates the stored parameters using a moving + average.""" + + def __init__( + self, + decay: float = 0.5, + use_dynamic_decay: bool = True, + decay_tau: float = 2000, + device: str | None = None, + ): + """Constructs `EMACallback`. + + @type decay: float + @param decay: Decay rate for the moving average. + @type use_dynamic_decay: bool + @param use_dynamic_decay: Use dynamic decay rate. If True, the + decay rate will be updated based on the number of updates. + @type decay_tau: float + @param decay_tau: Decay tau for the moving average. + @type device: str | None + @param device: Device to perform EMA on. + """ + self.decay = decay + self.use_dynamic_decay = use_dynamic_decay + self.decay_tau = decay_tau + self.device = device + + self.ema = None + self.collected_state_dict = None + + def on_fit_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Initialize `ModelEma` to keep a copy of the moving average of + the weights. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + + self.ema = ModelEma( + pl_module, + decay=self.decay, + use_dynamic_decay=self.use_dynamic_decay, + decay_tau=self.decay_tau, + device=self.device, + ) + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Update the stored parameters using a moving average. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + @type outputs: Any + @param outputs: Outputs from the training step. + @type batch: Any + @param batch: Batch data. + @type batch_idx: int + @param batch_idx: Batch index. 
+ """ + if batch_idx % trainer.accumulate_grad_batches == 0: + if self.ema is not None: + self.ema.update(pl_module) + + def on_validation_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Do validation using the stored parameters. Save the original + parameters before replacing with EMA version. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + + self.collected_state_dict = deepcopy(pl_module.state_dict()) + + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) + + def on_validation_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Restore original parameters to resume training later. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + if self.collected_state_dict is not None: + pl_module.load_state_dict(self.collected_state_dict) + + def on_train_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Update the LightningModule with the EMA weights. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) + logger.info("Model weights replaced with the EMA weights.") + + def on_save_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: dict, + ) -> None: # or dict? + """Save the EMA state_dict to the checkpoint. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + @type checkpoint: dict + @param checkpoint: Pytorch Lightning checkpoint. + """ + if self.ema is not None: + checkpoint["state_dict"] = self.ema.state_dict_ema + + def on_load_checkpoint(self, callback_state: dict) -> None: + """Load the EMA state_dict from the checkpoint. + + @type callback_state: dict + @param callback_state: Pytorch Lightning callback state. 
+ """ + if self.ema is not None: + self.ema.state_dict_ema = callback_state["state_dict"] diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 604b9c31..7669ed39 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -411,6 +411,13 @@ def check_validation_interval(self) -> Self: self.validation_interval = self.epochs return self + @model_validator(mode="after") + def reorder_callbacks(self) -> Self: + """Reorder callbacks so that EMA is the first callback, since it + needs to be updated before other callbacks.""" + self.callbacks.sort(key=lambda v: 0 if v.name == "EMACallback" else 1) + return self + class OnnxExportConfig(BaseModelExtraForbid): opset_version: PositiveInt = 12 diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py new file mode 100644 index 00000000..0780e783 --- /dev/null +++ b/tests/unittests/test_callbacks/test_ema.py @@ -0,0 +1,101 @@ +from copy import deepcopy + +import pytest +import torch +from pytorch_lightning import LightningModule, Trainer + +from luxonis_train.callbacks.ema import EMACallback, ModelEma + + +class SimpleModel(LightningModule): + def __init__(self): + super(SimpleModel, self).__init__() + self.layer = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.layer(x) + + +@pytest.fixture +def model(): + return SimpleModel() + + +@pytest.fixture +def ema_callback(): + return EMACallback() + + +def test_ema_initialization(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + assert isinstance(ema_callback.ema, ModelEma) + assert ema_callback.ema.decay == ema_callback.decay + assert ema_callback.ema.use_dynamic_decay == ema_callback.use_dynamic_decay + assert ema_callback.ema.device == ema_callback.device + + +def test_ema_update_on_batch_end(model, ema_callback): + trainer = Trainer() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + ema_callback.on_fit_start(trainer, model) + + initial_ema_state = { + k: v.clone() for k, v in ema_callback.ema.state_dict_ema.items() + } + + batch = torch.rand(2, 2) + batch_idx = 0 + + model.train() + outputs = model(batch) + model.zero_grad() + outputs.sum().backward() + optimizer.step() + + ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) + + # Check that the EMA has been updated + updated_state = ema_callback.ema.state_dict_ema + assert any( + not torch.equal(initial_ema_state[k], updated_state[k]) + for k in initial_ema_state + ) + + +def test_ema_state_saved_to_checkpoint(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + checkpoint = {} + ema_callback.on_save_checkpoint(trainer, model, checkpoint) + + assert "state_dict" in checkpoint or "state_dict_ema" in checkpoint + + +def test_load_from_checkpoint(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + checkpoint = {"state_dict": deepcopy(model.state_dict())} + ema_callback.on_load_checkpoint(checkpoint) + + assert ( + ema_callback.ema.state_dict_ema.keys() + == checkpoint["state_dict"].keys() + ) + + +def test_validation_epoch_start_and_end(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + ema_callback.on_validation_epoch_start(trainer, model) + assert ema_callback.collected_state_dict is not None + + ema_callback.on_validation_end(trainer, model) + for k in ema_callback.collected_state_dict.keys(): + assert torch.equal( + 
+            ema_callback.collected_state_dict[k], model.state_dict()[k]
+        )
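
A note on the arithmetic inside `ModelEma.update`: both branches apply the standard EMA rule `ema_new = decay * ema + (1 - decay) * model` (the `_foreach_lerp_` call and the `_foreach_mul_`/`_foreach_add_` fallback are equivalent), and with `use_dynamic_decay` the effective decay warms up from roughly zero towards the configured value following `decay * (1 - exp(-updates / decay_tau))`, so the average tracks the live weights closely early in training. A minimal sketch with plain tensors, outside of Lightning; the numbers are illustrative only:

```python
import math

import torch

decay = 0.9999

# The two update branches in `ModelEma.update` compute the same value:
# decay * ema + (1 - decay) * model == lerp(ema, model, weight=1 - decay).
ema = torch.randn(4)
model = torch.randn(4)
via_lerp = torch.lerp(ema, model, 1.0 - decay)
via_mul_add = ema * decay + model * (1.0 - decay)
assert torch.allclose(via_lerp, via_mul_add)


# With `use_dynamic_decay=True` the effective decay ramps up with the
# number of updates, mirroring the expression used in `ModelEma.update`.
def effective_decay(updates: int, tau: float = 2000.0) -> float:
    return decay * (1 - math.exp(-updates / tau))


for step in (1, 100, 2_000, 10_000):
    print(step, round(effective_decay(step), 4))
```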
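The `reorder_callbacks` validator added to `config.py` relies on `list.sort` being stable: every callback other than `EMACallback` maps to the same sort key, so their relative order is preserved while `EMACallback` moves to the front and refreshes the averaged weights before any other callback runs. The same idea on plain names (the list below is illustrative):

```python
callbacks = ["ExportOnTrainEnd", "EMACallback", "TestOnTrainEnd", "UploadCheckpoint"]
callbacks.sort(key=lambda name: 0 if name == "EMACallback" else 1)
print(callbacks)
# ['EMACallback', 'ExportOnTrainEnd', 'TestOnTrainEnd', 'UploadCheckpoint']
```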
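Since `EMACallback` is exported from `luxonis_train.callbacks` and registered in the `CALLBACKS` registry, it is normally enabled through the callbacks section of the training configuration; for a quick standalone smoke test it can also be attached directly to a PyTorch Lightning `Trainer`, in the same spirit as the unit tests above. `ToyModule` and the random dataset below are illustrative only and not part of this patch:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from luxonis_train.callbacks import EMACallback


class ToyModule(pl.LightningModule):
    """Minimal module used only to exercise the callback."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # Any scalar tensor serves as a loss for this smoke test.
        return self(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
    trainer = pl.Trainer(
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        # `on_train_end` copies the EMA weights back into the module.
        callbacks=[EMACallback(decay=0.9999, use_dynamic_decay=True)],
    )
    trainer.fit(ToyModule(), data)
```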