From 01918f00885a3295c0a83c5b6f8bd780ba4f9542 Mon Sep 17 00:00:00 2001 From: Nikita Date: Mon, 14 Oct 2024 13:55:27 +0000 Subject: [PATCH] fix: type issues --- luxonis_train/callbacks/ema.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 254322aa..98f17d6b 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -171,7 +171,8 @@ def on_validation_epoch_start( self.collected_state_dict = deepcopy(pl_module.state_dict()) - pl_module.load_state_dict(self.ema.state_dict_ema) + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) def on_validation_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -183,7 +184,8 @@ def on_validation_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. """ - pl_module.load_state_dict(self.collected_state_dict) + if self.collected_state_dict is not None: + pl_module.load_state_dict(self.collected_state_dict) def on_train_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -195,8 +197,9 @@ def on_train_end( @type pl_module: L{pl.LightningModule} @param pl_module: Pytorch Lightning module. """ - pl_module.load_state_dict(self.ema.state_dict_ema) - logger.info("Model weights replaced with the EMA weights.") + if self.ema is not None: + pl_module.load_state_dict(self.ema.state_dict_ema) + logger.info("Model weights replaced with the EMA weights.") def on_save_checkpoint( self, @@ -213,7 +216,8 @@ def on_save_checkpoint( @type checkpoint: dict @param checkpoint: Pytorch Lightning checkpoint. """ - checkpoint["state_dict"] = self.ema.state_dict_ema + if self.ema is not None: + checkpoint["state_dict"] = self.ema.state_dict_ema def on_load_checkpoint(self, callback_state: dict) -> None: """Load the EMA state_dict from the checkpoint. @@ -221,4 +225,5 @@ def on_load_checkpoint(self, callback_state: dict) -> None: @type callback_state: dict @param callback_state: Pytorch Lightning callback state. """ - self.ema.state_dict_ema = callback_state["state_dict"] + if self.ema is not None: + self.ema.state_dict_ema = callback_state["state_dict"]