Skip to content

Commit

Permalink
Mixed Precision Training (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin authored Oct 31, 2024
1 parent 4fe0570 commit 1d83a2c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,11 @@ def forward(
self.alpha * pred_score.pow(self.gamma) * (1 - label)
+ target_score * label
)
ce_loss = F.binary_cross_entropy(
pred_score.float(), target_score.float(), reduction="none"
)
with torch.amp.autocast(
device_type=pred_score.device.type, enabled=False
):
ce_loss = F.binary_cross_entropy(
pred_score.float(), target_score.float(), reduction="none"
)
loss = (ce_loss * weight).sum()
return loss
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ class TrainerConfig(BaseModelExtraForbid):
preprocessing: PreprocessingConfig = PreprocessingConfig()
use_rich_progress_bar: bool = True

precision: Literal["16-mixed", "32"] = "32"
accelerator: Literal["auto", "cpu", "gpu", "tpu"] = "auto"
devices: int | list[int] | str = "auto"
strategy: Literal["auto", "ddp"] = "auto"
Expand Down
1 change: 1 addition & 0 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
callbacks=LuxonisRichProgressBar()
if self.cfg.trainer.use_rich_progress_bar
else LuxonisTQDMProgressBar(),
precision=self.cfg.trainer.precision,
)

self.train_augmentations = Augmentations(
Expand Down
15 changes: 7 additions & 8 deletions luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,13 @@ def forward(
node_inputs.append(computed[pred])
else:
node_inputs.append({"features": [inputs[pred]]})

outputs = node.run(node_inputs)

computed[node_name] = outputs

del node_inputs

if (
compute_loss
and node_name in self.losses
Expand All @@ -420,20 +424,15 @@ def forward(
node_name
].items():
viz = combine_visualizations(
visualizer.run(
images,
images,
outputs,
labels,
),
visualizer.run(images, images, outputs, labels),
)
visualizations[node_name][viz_name] = viz

for computed_name in list(computed.keys()):
if computed_name in self.outputs:
continue
for node_name in unprocessed:
if computed_name in self.graph[node_name]:
for unprocessed_name in unprocessed:
if computed_name in self.graph[unprocessed_name]:
break
else:
del computed[computed_name]
Expand Down

0 comments on commit 1d83a2c

Please sign in to comment.