Skip to content

Commit

Permalink
fix: type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn committed Oct 14, 2024
1 parent 22bba41 commit 01918f0
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions luxonis_train/callbacks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def on_validation_epoch_start(

self.collected_state_dict = deepcopy(pl_module.state_dict())

pl_module.load_state_dict(self.ema.state_dict_ema)
if self.ema is not None:
pl_module.load_state_dict(self.ema.state_dict_ema)

def on_validation_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
Expand All @@ -183,7 +184,8 @@ def on_validation_end(
@type pl_module: L{pl.LightningModule}
@param pl_module: Pytorch Lightning module.
"""
pl_module.load_state_dict(self.collected_state_dict)
if self.collected_state_dict is not None:
pl_module.load_state_dict(self.collected_state_dict)

def on_train_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
Expand All @@ -195,8 +197,9 @@ def on_train_end(
@type pl_module: L{pl.LightningModule}
@param pl_module: Pytorch Lightning module.
"""
pl_module.load_state_dict(self.ema.state_dict_ema)
logger.info("Model weights replaced with the EMA weights.")
if self.ema is not None:
pl_module.load_state_dict(self.ema.state_dict_ema)

Check warning on line 201 in luxonis_train/callbacks/ema.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/callbacks/ema.py#L200-L201

Added lines #L200 - L201 were not covered by tests
logger.info("Model weights replaced with the EMA weights.")

def on_save_checkpoint(
self,
Expand All @@ -213,12 +216,14 @@ def on_save_checkpoint(
@type checkpoint: dict
@param checkpoint: Pytorch Lightning checkpoint.
"""
checkpoint["state_dict"] = self.ema.state_dict_ema
if self.ema is not None:
checkpoint["state_dict"] = self.ema.state_dict_ema

def on_load_checkpoint(self, callback_state: dict) -> None:
"""Load the EMA state_dict from the checkpoint.
@type callback_state: dict
@param callback_state: Pytorch Lightning callback state.
"""
self.ema.state_dict_ema = callback_state["state_dict"]
if self.ema is not None:
self.ema.state_dict_ema = callback_state["state_dict"]

0 comments on commit 01918f0

Please sign in to comment.