diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index 32e25257..a81d5a45 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -270,8 +270,11 @@ def forward( self.alpha * pred_score.pow(self.gamma) * (1 - label) + target_score * label ) - ce_loss = F.binary_cross_entropy( - pred_score.float(), target_score.float(), reduction="none" - ) + with torch.amp.autocast( + device_type=pred_score.device.type, enabled=False + ): + ce_loss = F.binary_cross_entropy( + pred_score.float(), target_score.float(), reduction="none" + ) loss = (ce_loss * weight).sum() return loss diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 604b9c31..84424f72 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -340,6 +340,7 @@ class TrainerConfig(BaseModelExtraForbid): preprocessing: PreprocessingConfig = PreprocessingConfig() use_rich_progress_bar: bool = True + precision: Literal["16-mixed", "32"] = "32" accelerator: Literal["auto", "cpu", "gpu", "tpu"] = "auto" devices: int | list[int] | str = "auto" strategy: Literal["auto", "ddp"] = "auto" diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 7f21e5e4..e292ecdb 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -108,6 +108,7 @@ def __init__( callbacks=LuxonisRichProgressBar() if self.cfg.trainer.use_rich_progress_bar else LuxonisTQDMProgressBar(), + precision=self.cfg.trainer.precision, ) self.train_augmentations = Augmentations( diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index e10df8c0..aec8e06a 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -392,9 +392,13 @@ def forward( node_inputs.append(computed[pred]) else: node_inputs.append({"features": [inputs[pred]]}) + outputs = node.run(node_inputs) + computed[node_name] = outputs + del node_inputs + if ( compute_loss and node_name in self.losses @@ -420,20 +424,15 @@ def forward( node_name ].items(): viz = combine_visualizations( - visualizer.run( - images, - images, - outputs, - labels, - ), + visualizer.run(images, images, outputs, labels), ) visualizations[node_name][viz_name] = viz for computed_name in list(computed.keys()): if computed_name in self.outputs: continue - for node_name in unprocessed: - if computed_name in self.graph[node_name]: + for unprocessed_name in unprocessed: + if computed_name in self.graph[unprocessed_name]: break else: del computed[computed_name]