Skip to content

Commit

Permalink
Upload All Checkpoints (#19)
Browse files Browse the repository at this point in the history
* uploading all checkpoints

* fix names

* removed comment
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent 08110dc commit 7a857cc
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions luxonis_train/callbacks/upload_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
Expand All @@ -25,37 +26,41 @@ def __init__(self, upload_directory: str):
)
self.logger = logging.getLogger(__name__)
self.last_logged_epoch = None
self.last_best_checkpoint = None
self.last_best_checkpoints = set()

def on_save_checkpoint(
    self,
    trainer: pl.Trainer,
    _: pl.LightningModule,
    checkpoint: dict[str, Any],
) -> None:
    """Upload the current best checkpoint of every ``ModelCheckpoint`` callback.

    Called by Lightning each time a checkpoint is saved. Uploads are
    de-duplicated in two ways: at most once per epoch (in case several
    ``ModelCheckpoint`` callbacks fire in the same epoch), and each
    ``best_model_path`` is uploaded only the first time it appears.

    @param trainer: Trainer whose callbacks are scanned for
        ``ModelCheckpoint`` instances.
    @param checkpoint: Full checkpoint dictionary to serialize and upload.
    """
    # Upload at most once per epoch, even when multiple ModelCheckpoint
    # callbacks trigger a save.
    if self.last_logged_epoch != trainer.current_epoch:
        # Best-model paths of all ModelCheckpoint callbacks that have
        # produced at least one checkpoint so far (empty paths skipped).
        checkpoint_paths = [
            c.best_model_path
            for c in trainer.callbacks  # type: ignore
            if isinstance(c, pl.callbacks.ModelCheckpoint)  # type: ignore
            and c.best_model_path
        ]

        for curr_best_checkpoint in checkpoint_paths:
            # Guard clause: skip anything already uploaded previously.
            if curr_best_checkpoint in self.last_best_checkpoints:
                continue

            self.logger.info(
                f"Started checkpoint upload to {self.fs.full_path}..."
            )
            # Name the temporary file after the checkpoint's parent
            # directory (e.g. ".../min_val_loss/model.ckpt" ->
            # "min_val_loss.ckpt") so uploads from different
            # ModelCheckpoint callbacks do not collide.
            temp_filename = (
                Path(curr_best_checkpoint).parent.with_suffix(".ckpt").name
            )
            torch.save(checkpoint, temp_filename)
            try:
                self.fs.put_file(
                    local_path=temp_filename,
                    remote_path=temp_filename,
                    mlflow_instance=trainer.logger.experiment.get(  # type: ignore
                        "mlflow", None
                    ),
                )
            finally:
                # Always remove the local temp file, even if the
                # upload raises — the original leaked it on failure.
                os.remove(temp_filename)
            self.logger.info("Checkpoint upload finished")
            self.last_best_checkpoints.add(curr_best_checkpoint)

        self.last_logged_epoch = trainer.current_epoch

0 comments on commit 7a857cc

Please sign in to comment.