From 4c9449ba1edca9002337f5aa86bc217591b3e99f Mon Sep 17 00:00:00 2001 From: Nikita Date: Sun, 13 Oct 2024 13:33:43 +0000 Subject: [PATCH 01/17] feat: add EMACallback --- luxonis_train/callbacks/__init__.py | 3 + luxonis_train/callbacks/ema.py | 189 ++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 luxonis_train/callbacks/ema.py diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py index 95f860a1..a751e322 100644 --- a/luxonis_train/callbacks/__init__.py +++ b/luxonis_train/callbacks/__init__.py @@ -13,6 +13,7 @@ from luxonis_train.utils.registry import CALLBACKS from .archive_on_train_end import ArchiveOnTrainEnd +from .ema import EMACallback from .export_on_train_end import ExportOnTrainEnd from .gpu_stats_monitor import GPUStatsMonitor from .luxonis_progress_bar import ( @@ -34,6 +35,7 @@ CALLBACKS.register_module(module=StochasticWeightAveraging) CALLBACKS.register_module(module=Timer) CALLBACKS.register_module(module=ModelPruning) +CALLBACKS.register_module(module=EMACallback) __all__ = [ @@ -47,4 +49,5 @@ "TestOnTrainEnd", "UploadCheckpoint", "GPUStatsMonitor", + "EMACallback", ] diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py new file mode 100644 index 00000000..ab1c8ec4 --- /dev/null +++ b/luxonis_train/callbacks/ema.py @@ -0,0 +1,189 @@ +import logging +from typing import Any, List, Tuple, Union + +import pytorch_lightning as pl +import torch +from copy import deepcopy +from torch import nn +from torch import Tensor +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.types import STEP_OUTPUT +import math + + +logger = logging.getLogger(__name__) + + +class ModelEma(nn.Module): + """Model Exponential Moving Average. + Keeps a moving average of everything in the model.state_dict (parameters and buffers). 
+ """ + + def __init__(self, model: pl.LightningModule, decay: float = 0.9999, use_dynamic_decay: bool = True, device: str = None): + """Constructs `ModelEma`. + + @type model: L{pl.LightningModule} + @param model: Pytorch Lightning module. + @type decay: float + @param decay: Decay rate for the moving average. + @type use_dynamic_decay: bool + @param use_dynamic_decay: Use dynamic decay rate. + @type device: str + @param device: Device to perform EMA on. + """ + super(ModelEma, self).__init__() + model.eval() + self.state_dict = deepcopy(model.state_dict()) + model.train() + + for p in self.state_dict.values(): + p.requires_grad = False + self.updates = 0 + self.decay = decay + self.use_dynamic_decay = use_dynamic_decay + self.device = device + if self.device is not None: + self.state_dict = {k: v.to(device=device) for k, v in self.state_dict.items()} + + def update(self, model: pl.LightningModule) -> None: + """Update the stored parameters using a moving average + + @type model: L{pl.LightningModule} + @param model: Pytorch Lightning module. + """ + with torch.no_grad(): + for k, ema_p in self.state_dict.items(): + if ema_p.dtype.is_floating_point: + self.updates += 1 + + if self.use_dynamic_decay: + decay = self.decay * (1 - math.exp(-self.updates / 2000)) + else: + decay = self.decay + + model_p = model.state_dict()[k] + if self.device is not None: + model_p = model_p.to(device=self.device) + ema_p *= decay + ema_p += (1. - decay) * model_p + +class EMACallback(Callback): + """ + Callback that updates the stored parameters using a moving average. + """ + + def __init__(self, decay: float = 0.9999, use_dynamic_decay: bool = True, use_ema_weights: bool = True, device: str = None): + """Constructs `EMACallback`. + + @type decay: float + @param decay: Decay rate for the moving average. + @type use_dynamic_decay: bool + @param use_dynamic_decay: Use dynamic decay rate. If True, the decay rate will be updated based on the number of updates. 
+ @type use_ema_weights: bool + @param use_ema_weights: Use EMA weights (replace model weights with EMA weights) + @type device: str + @param device: Device to perform EMA on. + """ + self.decay = decay + self.use_dynamic_decay = use_dynamic_decay + self.use_ema_weights = use_ema_weights + self.device = device + + self.ema = None + self.collected_state_dict = None + + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Initialize `ModelEma` to keep a copy of the moving average of the weights + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + + self.ema = ModelEma(pl_module, decay=self.decay, use_dynamic_decay = self.use_dynamic_decay, device=self.device) + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Update the stored parameters using a moving average + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + @type outputs: Any + @param outputs: Outputs from the training step. + @type batch: Any + @param batch: Batch data. + @type batch_idx: int + @param batch_idx: Batch index. + """ + + self.ema.update(pl_module) + + def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Do validation using the stored parameters. Save the original parameters before replacing with EMA version. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. 
+ """ + + self.collected_state_dict = deepcopy(pl_module.state_dict()) + + pl_module.load_state_dict(self.ema.state_dict) + + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Restore original parameters to resume training later + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + pl_module.load_state_dict(self.collected_state_dict) + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Update the LightningModule with the EMA weights + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + if self.use_ema_weights: + pl_module.load_state_dict(self.ema.state_dict) + logger.info("Model weights replaced with the EMA weights.") + + def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None: # or dict? + """Save the EMA state_dict to the checkpoint + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + @type checkpoint: dict + @param checkpoint: Pytorch Lightning checkpoint. + """ + if self.use_ema_weights: + checkpoint["state_dict"] = self.ema.state_dict + elif self.ema is not None: + checkpoint["state_dict_ema"] = self.ema.state_dict + + def on_load_checkpoint(self, callback_state: dict) -> None: + """Load the EMA state_dict from the checkpoint + + @type callback_state: dict + @param callback_state: Pytorch Lightning callback state. 
+ """ + if self.use_ema_weights: + self.ema.state_dict = callback_state["state_dict"] + elif self.ema is not None: + self.ema.state_dict = callback_state["state_dict_ema"] \ No newline at end of file From 5a023742927365a926a3bf4bfeb4c6c860592085 Mon Sep 17 00:00:00 2001 From: Nikita Date: Sun, 13 Oct 2024 13:34:14 +0000 Subject: [PATCH 02/17] test: EMACallback test --- tests/unittests/test_callbacks/test_ema.py | 79 ++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 tests/unittests/test_callbacks/test_ema.py diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py new file mode 100644 index 00000000..5dd84aef --- /dev/null +++ b/tests/unittests/test_callbacks/test_ema.py @@ -0,0 +1,79 @@ +import pytest +import torch +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.utilities.types import STEP_OUTPUT +from unittest.mock import MagicMock +from copy import deepcopy + +from luxonis_train.callbacks.ema import EMACallback, ModelEma + +class SimpleModel(LightningModule): + def __init__(self): + super(SimpleModel, self).__init__() + self.layer = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.layer(x) + +@pytest.fixture +def model(): + return SimpleModel() + +@pytest.fixture +def ema_callback(): + return EMACallback() + +def test_ema_initialization(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + assert isinstance(ema_callback.ema, ModelEma) + assert ema_callback.ema.decay == ema_callback.decay + assert ema_callback.ema.use_dynamic_decay == ema_callback.use_dynamic_decay + assert ema_callback.ema.device == ema_callback.device + +def test_ema_update_on_batch_end(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + initial_ema_state = {k: v.clone() for k, v in ema_callback.ema.state_dict.items()} + + outputs = None # Use a dummy output + batch = torch.rand(2, 2) + batch_idx = 0 + + 
ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) + + # Check that the EMA has been updated + updated_state = ema_callback.ema.state_dict + assert any(not torch.equal(initial_ema_state[k], updated_state[k]) for k in initial_ema_state) + +def test_ema_state_saved_to_checkpoint(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + checkpoint = {} + ema_callback.on_save_checkpoint(trainer, model, checkpoint) + + assert "state_dict" in checkpoint or "state_dict_ema" in checkpoint + +def test_load_from_checkpoint(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + checkpoint = {"state_dict": deepcopy(model.state_dict())} + ema_callback.on_load_checkpoint(checkpoint) + + assert ema_callback.ema.state_dict.keys() == checkpoint["state_dict"].keys() + +def test_validation_epoch_start_and_end(model, ema_callback): + trainer = Trainer() + ema_callback.on_fit_start(trainer, model) + + ema_callback.on_validation_epoch_start(trainer, model) + assert ema_callback.collected_state_dict is not None + + ema_callback.on_validation_end(trainer, model) + for k in ema_callback.collected_state_dict.keys(): + assert torch.equal(ema_callback.collected_state_dict[k], model.state_dict()[k]) + From cde31765a1c5a3e211cfdb43826bf3f210309a26 Mon Sep 17 00:00:00 2001 From: Nikita Date: Sun, 13 Oct 2024 13:34:55 +0000 Subject: [PATCH 03/17] style: formatting --- luxonis_train/callbacks/ema.py | 106 ++++++++++++++------- tests/unittests/test_callbacks/test_ema.py | 39 +++++--- 2 files changed, 99 insertions(+), 46 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index ab1c8ec4..56343878 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -1,25 +1,31 @@ import logging -from typing import Any, List, Tuple, Union +import math +from copy import deepcopy +from typing import Any import pytorch_lightning as pl import torch -from copy 
import deepcopy -from torch import nn -from torch import Tensor from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.types import STEP_OUTPUT -import math - +from torch import nn logger = logging.getLogger(__name__) class ModelEma(nn.Module): """Model Exponential Moving Average. - Keeps a moving average of everything in the model.state_dict (parameters and buffers). + + Keeps a moving average of everything in the model.state_dict + (parameters and buffers). """ - def __init__(self, model: pl.LightningModule, decay: float = 0.9999, use_dynamic_decay: bool = True, device: str = None): + def __init__( + self, + model: pl.LightningModule, + decay: float = 0.9999, + use_dynamic_decay: bool = True, + device: str = None, + ): """Constructs `ModelEma`. @type model: L{pl.LightningModule} @@ -43,10 +49,12 @@ def __init__(self, model: pl.LightningModule, decay: float = 0.9999, use_dynamic self.use_dynamic_decay = use_dynamic_decay self.device = device if self.device is not None: - self.state_dict = {k: v.to(device=device) for k, v in self.state_dict.items()} + self.state_dict = { + k: v.to(device=device) for k, v in self.state_dict.items() + } def update(self, model: pl.LightningModule) -> None: - """Update the stored parameters using a moving average + """Update the stored parameters using a moving average. @type model: L{pl.LightningModule} @param model: Pytorch Lightning module. @@ -57,7 +65,9 @@ def update(self, model: pl.LightningModule) -> None: self.updates += 1 if self.use_dynamic_decay: - decay = self.decay * (1 - math.exp(-self.updates / 2000)) + decay = self.decay * ( + 1 - math.exp(-self.updates / 2000) + ) else: decay = self.decay @@ -65,22 +75,30 @@ def update(self, model: pl.LightningModule) -> None: if self.device is not None: model_p = model_p.to(device=self.device) ema_p *= decay - ema_p += (1. 
- decay) * model_p + ema_p += (1.0 - decay) * model_p + class EMACallback(Callback): - """ - Callback that updates the stored parameters using a moving average. - """ + """Callback that updates the stored parameters using a moving + average.""" - def __init__(self, decay: float = 0.9999, use_dynamic_decay: bool = True, use_ema_weights: bool = True, device: str = None): + def __init__( + self, + decay: float = 0.9999, + use_dynamic_decay: bool = True, + use_ema_weights: bool = True, + device: str = None, + ): """Constructs `EMACallback`. @type decay: float @param decay: Decay rate for the moving average. @type use_dynamic_decay: bool - @param use_dynamic_decay: Use dynamic decay rate. If True, the decay rate will be updated based on the number of updates. + @param use_dynamic_decay: Use dynamic decay rate. If True, the + decay rate will be updated based on the number of updates. @type use_ema_weights: bool - @param use_ema_weights: Use EMA weights (replace model weights with EMA weights) + @param use_ema_weights: Use EMA weights (replace model weights + with EMA weights) @type device: str @param device: Device to perform EMA on. """ @@ -92,8 +110,11 @@ def __init__(self, decay: float = 0.9999, use_dynamic_decay: bool = True, use_em self.ema = None self.collected_state_dict = None - def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - """Initialize `ModelEma` to keep a copy of the moving average of the weights + def on_fit_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Initialize `ModelEma` to keep a copy of the moving average of + the weights. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. @@ -101,7 +122,12 @@ def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> No @param pl_module: Pytorch Lightning module. 
""" - self.ema = ModelEma(pl_module, decay=self.decay, use_dynamic_decay = self.use_dynamic_decay, device=self.device) + self.ema = ModelEma( + pl_module, + decay=self.decay, + use_dynamic_decay=self.use_dynamic_decay, + device=self.device, + ) def on_train_batch_end( self, @@ -111,7 +137,7 @@ def on_train_batch_end( batch: Any, batch_idx: int, ) -> None: - """Update the stored parameters using a moving average + """Update the stored parameters using a moving average. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. @@ -124,11 +150,14 @@ def on_train_batch_end( @type batch_idx: int @param batch_idx: Batch index. """ - + self.ema.update(pl_module) - def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - """Do validation using the stored parameters. Save the original parameters before replacing with EMA version. + def on_validation_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Do validation using the stored parameters. Save the original + parameters before replacing with EMA version. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. @@ -140,8 +169,10 @@ def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.Lightning pl_module.load_state_dict(self.ema.state_dict) - def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - """Restore original parameters to resume training later + def on_validation_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Restore original parameters to resume training later. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. 
@@ -150,8 +181,10 @@ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) """ pl_module.load_state_dict(self.collected_state_dict) - def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - """Update the LightningModule with the EMA weights + def on_train_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Update the LightningModule with the EMA weights. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. @@ -162,8 +195,13 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> No pl_module.load_state_dict(self.ema.state_dict) logger.info("Model weights replaced with the EMA weights.") - def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None: # or dict? - """Save the EMA state_dict to the checkpoint + def on_save_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: dict, + ) -> None: # or dict? + """Save the EMA state_dict to the checkpoint. @type trainer: L{pl.Trainer} @param trainer: Pytorch Lightning trainer. @@ -178,12 +216,12 @@ def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint["state_dict_ema"] = self.ema.state_dict def on_load_checkpoint(self, callback_state: dict) -> None: - """Load the EMA state_dict from the checkpoint - + """Load the EMA state_dict from the checkpoint. + @type callback_state: dict @param callback_state: Pytorch Lightning callback state. 
""" if self.use_ema_weights: self.ema.state_dict = callback_state["state_dict"] elif self.ema is not None: - self.ema.state_dict = callback_state["state_dict_ema"] \ No newline at end of file + self.ema.state_dict = callback_state["state_dict_ema"] diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index 5dd84aef..c6876ae2 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -1,12 +1,12 @@ +from copy import deepcopy + import pytest import torch -from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.utilities.types import STEP_OUTPUT -from unittest.mock import MagicMock -from copy import deepcopy +from pytorch_lightning import LightningModule, Trainer from luxonis_train.callbacks.ema import EMACallback, ModelEma + class SimpleModel(LightningModule): def __init__(self): super(SimpleModel, self).__init__() @@ -15,38 +15,48 @@ def __init__(self): def forward(self, x): return self.layer(x) + @pytest.fixture def model(): return SimpleModel() + @pytest.fixture def ema_callback(): return EMACallback() + def test_ema_initialization(model, ema_callback): trainer = Trainer() ema_callback.on_fit_start(trainer, model) - + assert isinstance(ema_callback.ema, ModelEma) assert ema_callback.ema.decay == ema_callback.decay assert ema_callback.ema.use_dynamic_decay == ema_callback.use_dynamic_decay assert ema_callback.ema.device == ema_callback.device + def test_ema_update_on_batch_end(model, ema_callback): trainer = Trainer() ema_callback.on_fit_start(trainer, model) - initial_ema_state = {k: v.clone() for k, v in ema_callback.ema.state_dict.items()} - + initial_ema_state = { + k: v.clone() for k, v in ema_callback.ema.state_dict.items() + } + outputs = None # Use a dummy output batch = torch.rand(2, 2) batch_idx = 0 - + ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) # Check that the EMA has been updated updated_state = 
ema_callback.ema.state_dict - assert any(not torch.equal(initial_ema_state[k], updated_state[k]) for k in initial_ema_state) + assert any( + not torch.equal(initial_ema_state[k], updated_state[k]) + for k in initial_ema_state + ) + def test_ema_state_saved_to_checkpoint(model, ema_callback): trainer = Trainer() @@ -57,6 +67,7 @@ def test_ema_state_saved_to_checkpoint(model, ema_callback): assert "state_dict" in checkpoint or "state_dict_ema" in checkpoint + def test_load_from_checkpoint(model, ema_callback): trainer = Trainer() ema_callback.on_fit_start(trainer, model) @@ -64,7 +75,10 @@ def test_load_from_checkpoint(model, ema_callback): checkpoint = {"state_dict": deepcopy(model.state_dict())} ema_callback.on_load_checkpoint(checkpoint) - assert ema_callback.ema.state_dict.keys() == checkpoint["state_dict"].keys() + assert ( + ema_callback.ema.state_dict.keys() == checkpoint["state_dict"].keys() + ) + def test_validation_epoch_start_and_end(model, ema_callback): trainer = Trainer() @@ -75,5 +89,6 @@ def test_validation_epoch_start_and_end(model, ema_callback): ema_callback.on_validation_end(trainer, model) for k in ema_callback.collected_state_dict.keys(): - assert torch.equal(ema_callback.collected_state_dict[k], model.state_dict()[k]) - + assert torch.equal( + ema_callback.collected_state_dict[k], model.state_dict()[k] + ) From e26d799d99e6f44720434b7c7a92fb61ea8b807c Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 12:50:59 +0000 Subject: [PATCH 04/17] feat: add decay_tau argument --- luxonis_train/callbacks/ema.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 56343878..f3acd7d0 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -24,6 +24,7 @@ def __init__( model: pl.LightningModule, decay: float = 0.9999, use_dynamic_decay: bool = True, + decay_tau: float = 2000, device: str = None, ): """Constructs `ModelEma`. 
@@ -34,6 +35,8 @@ def __init__( @param decay: Decay rate for the moving average. @type use_dynamic_decay: bool @param use_dynamic_decay: Use dynamic decay rate. + @type decay_tau: float + @param decay_tau: Decay tau for the moving average. @type device: str @param device: Device to perform EMA on. """ @@ -47,6 +50,7 @@ def __init__( self.updates = 0 self.decay = decay self.use_dynamic_decay = use_dynamic_decay + self.decay_tau = decay_tau self.device = device if self.device is not None: self.state_dict = { @@ -66,7 +70,7 @@ def update(self, model: pl.LightningModule) -> None: if self.use_dynamic_decay: decay = self.decay * ( - 1 - math.exp(-self.updates / 2000) + 1 - math.exp(-self.updates / self.decay_tau) ) else: decay = self.decay @@ -84,8 +88,9 @@ class EMACallback(Callback): def __init__( self, - decay: float = 0.9999, + decay: float = 0.5, use_dynamic_decay: bool = True, + decay_tau: float = 2000, use_ema_weights: bool = True, device: str = None, ): @@ -96,6 +101,8 @@ def __init__( @type use_dynamic_decay: bool @param use_dynamic_decay: Use dynamic decay rate. If True, the decay rate will be updated based on the number of updates. + @type decay_tau: float + @param decay_tau: Decay tau for the moving average. 
@type use_ema_weights: bool @param use_ema_weights: Use EMA weights (replace model weights with EMA weights) @@ -104,6 +111,7 @@ def __init__( """ self.decay = decay self.use_dynamic_decay = use_dynamic_decay + self.decay_tau = decay_tau self.use_ema_weights = use_ema_weights self.device = device @@ -126,6 +134,7 @@ def on_fit_start( pl_module, decay=self.decay, use_dynamic_decay=self.use_dynamic_decay, + decay_tau=self.decay_tau, device=self.device, ) From e857e9439cd2f032f9fe762871184965b498cc70 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:01:12 +0000 Subject: [PATCH 05/17] feat: always replace original weights with EMA weights --- luxonis_train/callbacks/ema.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index f3acd7d0..ee9a8cd8 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -91,7 +91,6 @@ def __init__( decay: float = 0.5, use_dynamic_decay: bool = True, decay_tau: float = 2000, - use_ema_weights: bool = True, device: str = None, ): """Constructs `EMACallback`. @@ -103,16 +102,12 @@ def __init__( decay rate will be updated based on the number of updates. @type decay_tau: float @param decay_tau: Decay tau for the moving average. - @type use_ema_weights: bool - @param use_ema_weights: Use EMA weights (replace model weights - with EMA weights) @type device: str @param device: Device to perform EMA on. """ self.decay = decay self.use_dynamic_decay = use_dynamic_decay self.decay_tau = decay_tau - self.use_ema_weights = use_ema_weights self.device = device self.ema = None @@ -200,9 +195,8 @@ def on_train_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. 
""" - if self.use_ema_weights: - pl_module.load_state_dict(self.ema.state_dict) - logger.info("Model weights replaced with the EMA weights.") + pl_module.load_state_dict(self.ema.state_dict) + logger.info("Model weights replaced with the EMA weights.") def on_save_checkpoint( self, @@ -219,10 +213,7 @@ def on_save_checkpoint( @type checkpoint: dict @param checkpoint: Pytorch Lightning checkpoint. """ - if self.use_ema_weights: - checkpoint["state_dict"] = self.ema.state_dict - elif self.ema is not None: - checkpoint["state_dict_ema"] = self.ema.state_dict + checkpoint["state_dict"] = self.ema.state_dict def on_load_checkpoint(self, callback_state: dict) -> None: """Load the EMA state_dict from the checkpoint. @@ -230,7 +221,4 @@ def on_load_checkpoint(self, callback_state: dict) -> None: @type callback_state: dict @param callback_state: Pytorch Lightning callback state. """ - if self.use_ema_weights: - self.ema.state_dict = callback_state["state_dict"] - elif self.ema is not None: - self.ema.state_dict = callback_state["state_dict_ema"] + self.ema.state_dict = callback_state["state_dict"] From 0f1c94750092f9057692758cea244959fc3d988b Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:34:20 +0000 Subject: [PATCH 06/17] fix: type issues --- luxonis_train/callbacks/ema.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index ee9a8cd8..091d23ad 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -25,7 +25,7 @@ def __init__( decay: float = 0.9999, use_dynamic_decay: bool = True, decay_tau: float = 2000, - device: str = None, + device: str | None = None, ): """Constructs `ModelEma`. @@ -37,15 +37,15 @@ def __init__( @param use_dynamic_decay: Use dynamic decay rate. @type decay_tau: float @param decay_tau: Decay tau for the moving average. 
- @type device: str + @type device: str | None @param device: Device to perform EMA on. """ super(ModelEma, self).__init__() model.eval() - self.state_dict = deepcopy(model.state_dict()) + self.state_dict_ema = deepcopy(model.state_dict()) model.train() - for p in self.state_dict.values(): + for p in self.state_dict_ema.values(): p.requires_grad = False self.updates = 0 self.decay = decay @@ -53,8 +53,8 @@ def __init__( self.decay_tau = decay_tau self.device = device if self.device is not None: - self.state_dict = { - k: v.to(device=device) for k, v in self.state_dict.items() + self.state_dict_ema = { + k: v.to(device=device) for k, v in self.state_dict_ema.items() } def update(self, model: pl.LightningModule) -> None: @@ -64,7 +64,7 @@ def update(self, model: pl.LightningModule) -> None: @param model: Pytorch Lightning module. """ with torch.no_grad(): - for k, ema_p in self.state_dict.items(): + for k, ema_p in self.state_dict_ema.items(): if ema_p.dtype.is_floating_point: self.updates += 1 @@ -91,7 +91,7 @@ def __init__( decay: float = 0.5, use_dynamic_decay: bool = True, decay_tau: float = 2000, - device: str = None, + device: str | None = None, ): """Constructs `EMACallback`. @@ -102,7 +102,7 @@ def __init__( decay rate will be updated based on the number of updates. @type decay_tau: float @param decay_tau: Decay tau for the moving average. - @type device: str + @type device: str | None @param device: Device to perform EMA on. """ self.decay = decay @@ -155,7 +155,7 @@ def on_train_batch_end( @param batch_idx: Batch index. 
""" - self.ema.update(pl_module) + self.ema.update(pl_module) # type: ignore def on_validation_epoch_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -171,7 +171,7 @@ def on_validation_epoch_start( self.collected_state_dict = deepcopy(pl_module.state_dict()) - pl_module.load_state_dict(self.ema.state_dict) + pl_module.load_state_dict(self.ema.state_dict_ema) def on_validation_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -195,7 +195,7 @@ def on_train_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. """ - pl_module.load_state_dict(self.ema.state_dict) + pl_module.load_state_dict(self.ema.state_dict_ema) logger.info("Model weights replaced with the EMA weights.") def on_save_checkpoint( @@ -213,7 +213,7 @@ def on_save_checkpoint( @type checkpoint: dict @param checkpoint: Pytorch Lightning checkpoint. """ - checkpoint["state_dict"] = self.ema.state_dict + checkpoint["state_dict"] = self.ema.state_dict_ema def on_load_checkpoint(self, callback_state: dict) -> None: """Load the EMA state_dict from the checkpoint. @@ -221,4 +221,4 @@ def on_load_checkpoint(self, callback_state: dict) -> None: @type callback_state: dict @param callback_state: Pytorch Lightning callback state. """ - self.ema.state_dict = callback_state["state_dict"] + self.ema.state_dict_ema = callback_state["state_dict"] From aec55544619ebf171bb519500ca00ded89be615e Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:34:54 +0000 Subject: [PATCH 07/17] style: formatting --- luxonis_train/callbacks/ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 091d23ad..254322aa 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -155,7 +155,7 @@ def on_train_batch_end( @param batch_idx: Batch index. 
""" - self.ema.update(pl_module) # type: ignore + self.ema.update(pl_module) # type: ignore def on_validation_epoch_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule From c875e6ecf0ea871edbe86b88d1a1db3808fef805 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:36:33 +0000 Subject: [PATCH 08/17] test: adjust tests --- tests/unittests/test_callbacks/test_ema.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index c6876ae2..21eee4b1 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -41,7 +41,7 @@ def test_ema_update_on_batch_end(model, ema_callback): ema_callback.on_fit_start(trainer, model) initial_ema_state = { - k: v.clone() for k, v in ema_callback.ema.state_dict.items() + k: v.clone() for k, v in ema_callback.ema.state_dict_ema.items() } outputs = None # Use a dummy output @@ -51,7 +51,7 @@ def test_ema_update_on_batch_end(model, ema_callback): ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) # Check that the EMA has been updated - updated_state = ema_callback.ema.state_dict + updated_state = ema_callback.ema.state_dict_ema assert any( not torch.equal(initial_ema_state[k], updated_state[k]) for k in initial_ema_state @@ -76,7 +76,8 @@ def test_load_from_checkpoint(model, ema_callback): ema_callback.on_load_checkpoint(checkpoint) assert ( - ema_callback.ema.state_dict.keys() == checkpoint["state_dict"].keys() + ema_callback.ema.state_dict_ema.keys() + == checkpoint["state_dict"].keys() ) From 22bba412fbfe5a02ef8b568324ec7c2f7c5ace18 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:39:31 +0000 Subject: [PATCH 09/17] docs: add EMACallback to README --- luxonis_train/callbacks/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md index 
64fbdf4f..f630f495 100644 --- a/luxonis_train/callbacks/README.md +++ b/luxonis_train/callbacks/README.md @@ -65,3 +65,16 @@ Callback to perform a test run at the end of the training. ## `UploadCheckpoint` Callback that uploads currently the best checkpoint (based on validation loss) to the tracker location - where all other logs are stored. + +## `EMACallback` + +Callback that updates the stored parameters using a moving average. + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------- | ------------- | ----------------------------------------------------------------------------------------------- | +| `decay` | `float` | `0.5` | Decay factor for the moving average. | +| `use_dynamic_decay` | `bool` | `True` | Whether to use dynamic decay. | +| `decay_tau` | `float` | `2000` | Decay tau for dynamic decay. | +| `device` | `str` | `None` | Device to use for the moving average. If `None`, the device is inferred from the model's device | From 01918f00885a3295c0a83c5b6f8bd780ba4f9542 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:55:27 +0000 Subject: [PATCH 10/17] fix: type issues --- luxonis_train/callbacks/ema.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 254322aa..98f17d6b 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -171,7 +171,8 @@ def on_validation_epoch_start( self.collected_state_dict = deepcopy(pl_module.state_dict()) - pl_module.load_state_dict(self.ema.state_dict_ema) + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) def on_validation_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -183,7 +184,8 @@ def on_validation_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. 
""" - pl_module.load_state_dict(self.collected_state_dict) + if self.collected_state_dict is not None: + pl_module.load_state_dict(self.collected_state_dict) def on_train_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -195,8 +197,9 @@ def on_train_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. """ - pl_module.load_state_dict(self.ema.state_dict_ema) - logger.info("Model weights replaced with the EMA weights.") + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) + logger.info("Model weights replaced with the EMA weights.") def on_save_checkpoint( self, @@ -213,7 +216,8 @@ def on_save_checkpoint( @type checkpoint: dict @param checkpoint: Pytorch Lightning checkpoint. """ - checkpoint["state_dict"] = self.ema.state_dict_ema + if self.ema is not None: + checkpoint["state_dict"] = self.ema.state_dict_ema def on_load_checkpoint(self, callback_state: dict) -> None: """Load the EMA state_dict from the checkpoint. @@ -221,4 +225,5 @@ def on_load_checkpoint(self, callback_state: dict) -> None: @type callback_state: dict @param callback_state: Pytorch Lightning callback state. 
""" - self.ema.state_dict_ema = callback_state["state_dict"] + if self.ema is not None: + self.ema.state_dict_ema = callback_state["state_dict"] From 2c57e0338fd0c2cf57a4c8663fe74563de9d4fff Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 15:07:38 +0000 Subject: [PATCH 11/17] test: fix test_ema_update_on_batch_end --- tests/unittests/test_callbacks/test_ema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index 21eee4b1..26cf9654 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -44,10 +44,12 @@ def test_ema_update_on_batch_end(model, ema_callback): k: v.clone() for k, v in ema_callback.ema.state_dict_ema.items() } - outputs = None # Use a dummy output batch = torch.rand(2, 2) batch_idx = 0 + model.train() + outputs = model(batch) + ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) # Check that the EMA has been updated From 5ad37372c45a8571901779a9510c278d09623ce2 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 15:21:54 +0000 Subject: [PATCH 12/17] test: fix test_ema_update_on_batch_end --- tests/unittests/test_callbacks/test_ema.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index 26cf9654..0780e783 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -38,6 +38,7 @@ def test_ema_initialization(model, ema_callback): def test_ema_update_on_batch_end(model, ema_callback): trainer = Trainer() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) ema_callback.on_fit_start(trainer, model) initial_ema_state = { @@ -49,6 +50,9 @@ def test_ema_update_on_batch_end(model, ema_callback): model.train() outputs = model(batch) + model.zero_grad() + outputs.sum().backward() + optimizer.step() 
ema_callback.on_train_batch_end(trainer, model, outputs, batch, batch_idx) From 2e39e91801ac18ba14ee309133beb5905ec4099c Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 15:33:54 +0000 Subject: [PATCH 13/17] fix: type issue --- luxonis_train/callbacks/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 98f17d6b..bdee800d 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -154,8 +154,8 @@ def on_train_batch_end( @type batch_idx: int @param batch_idx: Batch index. """ - - self.ema.update(pl_module) # type: ignore + if self.ema is not None: + self.ema.update(pl_module) def on_validation_epoch_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule From 78a22f25d9be1841df9392ed86773626788863cf Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 15:49:56 +0000 Subject: [PATCH 14/17] feat: reorder EMA callback to be the first callback --- luxonis_train/config/config.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 604b9c31..eab6b3df 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -411,6 +411,21 @@ def check_validation_interval(self) -> Self: self.validation_interval = self.epochs return self + @model_validator(mode="after") + def reorder_callbacks(self) -> Self: + """Reorder callbacks so that EMA is the first callback, since it + needs to be updated before other callbacks.""" + ema_index = None + for i, callback in enumerate(self.callbacks): + if callback.name == "EMACallback": + ema_index = i + break + if ema_index is not None: + ema_callback = self.callbacks.pop(ema_index) + self.callbacks.insert(0, ema_callback) + + return self + class OnnxExportConfig(BaseModelExtraForbid): opset_version: PositiveInt = 12 From d3d7ad0ad737b463e3dbc5ebcc15704a585f7229 Mon Sep 17 00:00:00 2001 From: Nikita 
Date: Tue, 15 Oct 2024 10:53:46 +0000 Subject: [PATCH 15/17] refactor: simplify reorder_callbacks() --- luxonis_train/config/config.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index eab6b3df..7669ed39 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -415,15 +415,7 @@ def check_validation_interval(self) -> Self: def reorder_callbacks(self) -> Self: """Reorder callbacks so that EMA is the first callback, since it needs to be updated before other callbacks.""" - ema_index = None - for i, callback in enumerate(self.callbacks): - if callback.name == "EMACallback": - ema_index = i - break - if ema_index is not None: - ema_callback = self.callbacks.pop(ema_index) - self.callbacks.insert(0, ema_callback) - + self.callbacks.sort(key=lambda v: 0 if v.name == "EMACallback" else 1) return self From 42a92db528006ef986f5a01382354d7d01887753 Mon Sep 17 00:00:00 2001 From: Nikita Date: Wed, 16 Oct 2024 12:40:00 +0000 Subject: [PATCH 16/17] feat: speedup EMA computation --- luxonis_train/callbacks/ema.py | 54 ++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index bdee800d..5ff580e4 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -60,26 +60,43 @@ def __init__( def update(self, model: pl.LightningModule) -> None: """Update the stored parameters using a moving average. + Source: U{} + + @license: U{Apache License 2.0} + @type model: L{pl.LightningModule} @param model: Pytorch Lightning module. 
""" with torch.no_grad(): - for k, ema_p in self.state_dict_ema.items(): - if ema_p.dtype.is_floating_point: - self.updates += 1 - - if self.use_dynamic_decay: - decay = self.decay * ( - 1 - math.exp(-self.updates / self.decay_tau) - ) - else: - decay = self.decay - - model_p = model.state_dict()[k] - if self.device is not None: - model_p = model_p.to(device=self.device) - ema_p *= decay - ema_p += (1.0 - decay) * model_p + self.updates += 1 + + if self.use_dynamic_decay: + decay = self.decay * ( + 1 - math.exp(-self.updates / self.decay_tau) + ) + else: + decay = self.decay + + ema_lerp_values = [] + model_lerp_values = [] + for ema_v, model_v in zip( + self.state_dict_ema.values(), model.state_dict().values() + ): + if ema_v.is_floating_point(): + ema_lerp_values.append(ema_v) + model_lerp_values.append(model_v) + else: + ema_v.copy_(model_v) + + if hasattr(torch, "_foreach_lerp_"): + torch._foreach_lerp_( + ema_lerp_values, model_lerp_values, weight=1.0 - decay + ) + else: + torch._foreach_mul_(ema_lerp_values, scalar=decay) + torch._foreach_add_( + ema_lerp_values, model_lerp_values, alpha=1.0 - decay + ) class EMACallback(Callback): @@ -154,8 +171,9 @@ def on_train_batch_end( @type batch_idx: int @param batch_idx: Batch index. 
""" - if self.ema is not None: - self.ema.update(pl_module) + if batch_idx % trainer.accumulate_grad_batches == 0: + if self.ema is not None: + self.ema.update(pl_module) def on_validation_epoch_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule From f232097dfdda428c3a582cfcc5cf00a1c5e8f97c Mon Sep 17 00:00:00 2001 From: Nikita Date: Wed, 16 Oct 2024 12:53:53 +0000 Subject: [PATCH 17/17] fix: move params to correct device --- luxonis_train/callbacks/ema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 5ff580e4..20c01c04 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -83,6 +83,8 @@ def update(self, model: pl.LightningModule) -> None: self.state_dict_ema.values(), model.state_dict().values() ): if ema_v.is_floating_point(): + if self.device is not None: + model_v = model_v.to(device=self.device) ema_lerp_values.append(ema_v) model_lerp_values.append(model_v) else: