From 859b54824f0df603b63caff84789887a708513cc Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 22 Jan 2025 11:15:41 +0100 Subject: [PATCH 1/6] fix ema checkpoint loading order --- luxonis_train/callbacks/ema.py | 35 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 20c01c04..598c72e6 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -25,7 +25,6 @@ def __init__( decay: float = 0.9999, use_dynamic_decay: bool = True, decay_tau: float = 2000, - device: str | None = None, ): """Constructs `ModelEma`. @@ -37,8 +36,6 @@ def __init__( @param use_dynamic_decay: Use dynamic decay rate. @type decay_tau: float @param decay_tau: Decay tau for the moving average. - @type device: str | None - @param device: Device to perform EMA on. """ super(ModelEma, self).__init__() model.eval() @@ -51,11 +48,6 @@ def __init__( self.decay = decay self.use_dynamic_decay = use_dynamic_decay self.decay_tau = decay_tau - self.device = device - if self.device is not None: - self.state_dict_ema = { - k: v.to(device=device) for k, v in self.state_dict_ema.items() - } def update(self, model: pl.LightningModule) -> None: """Update the stored parameters using a moving average. @@ -83,8 +75,6 @@ def update(self, model: pl.LightningModule) -> None: self.state_dict_ema.values(), model.state_dict().values() ): if ema_v.is_floating_point(): - if self.device is not None: - model_v = model_v.to(device=self.device) ema_lerp_values.append(ema_v) model_lerp_values.append(model_v) else: @@ -110,7 +100,6 @@ def __init__( decay: float = 0.5, use_dynamic_decay: bool = True, decay_tau: float = 2000, - device: str | None = None, ): """Constructs `EMACallback`. @@ -121,15 +110,13 @@ def __init__( decay rate will be updated based on the number of updates. @type decay_tau: float @param decay_tau: Decay tau for the moving average. - @type device: str | None - @param device: Device to perform EMA on. """ self.decay = decay self.use_dynamic_decay = use_dynamic_decay self.decay_tau = decay_tau - self.device = device self.ema = None + self.loaded_ema_state_dict = None self.collected_state_dict = None def on_fit_start( @@ -149,8 +136,15 @@ def on_fit_start( decay=self.decay, use_dynamic_decay=self.use_dynamic_decay, decay_tau=self.decay_tau, - device=self.device, ) + if self.loaded_ema_state_dict is not None: + target_device = next(iter(self.ema.state_dict_ema.values())).device + self.loaded_ema_state_dict = { + k: v.to(target_device) + for k, v in self.loaded_ema_state_dict.items() + } + self.ema.state_dict_ema = self.loaded_ema_state_dict + self.loaded_ema_state_dict = None def on_train_batch_end( self, @@ -239,11 +233,16 @@ def on_save_checkpoint( if self.ema is not None: checkpoint["state_dict"] = self.ema.state_dict_ema - def on_load_checkpoint(self, callback_state: dict) -> None: + def on_load_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + callback_state: dict, + ) -> None: """Load the EMA state_dict from the checkpoint. @type callback_state: dict @param callback_state: Pytorch Lightning callback state. """ - if self.ema is not None: - self.ema.state_dict_ema = callback_state["state_dict"] + if callback_state and "state_dict" in callback_state: + self.loaded_ema_state_dict = callback_state["state_dict"] From 181635f2f0e995cc647a598488cecb28f1be03d1 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 22 Jan 2025 11:31:25 +0100 Subject: [PATCH 2/6] fix ema test --- tests/unittests/test_callbacks/test_ema.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index 0780e783..ab7d7b6b 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -33,7 +33,6 @@ def test_ema_initialization(model, ema_callback): assert isinstance(ema_callback.ema, ModelEma) assert ema_callback.ema.decay == ema_callback.decay assert ema_callback.ema.use_dynamic_decay == ema_callback.use_dynamic_decay - assert ema_callback.ema.device == ema_callback.device def test_ema_update_on_batch_end(model, ema_callback): @@ -76,11 +75,10 @@ def test_ema_state_saved_to_checkpoint(model, ema_callback): def test_load_from_checkpoint(model, ema_callback): trainer = Trainer() - ema_callback.on_fit_start(trainer, model) checkpoint = {"state_dict": deepcopy(model.state_dict())} - ema_callback.on_load_checkpoint(checkpoint) - + ema_callback.on_load_checkpoint(trainer, model, checkpoint) + ema_callback.on_fit_start(trainer, model) assert ( ema_callback.ema.state_dict_ema.keys() == checkpoint["state_dict"].keys() From 7ffc7ef6166393d2b61c8b05ecc67642687a07af Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 22 Jan 2025 13:26:11 +0100 Subject: [PATCH 3/6] new resume_training param --- configs/README.md | 68 +++++++++++++++++++++------------- luxonis_train/__main__.py | 4 +- luxonis_train/config/config.py | 1 + luxonis_train/core/core.py | 26 ++++++++----- 4 files changed, 62 insertions(+), 37 deletions(-) diff --git a/configs/README.md b/configs/README.md index 69d77243..a941852e 100644 --- a/configs/README.md +++ b/configs/README.md @@ -202,31 +202,32 @@ loader: Here you can change everything related to actual training of the model. -| Key | Type | Default value | Description | -| ------------------------- | ---------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ | -| `seed` | `int` | `None` | Seed for reproducibility | -| `deterministic` | `bool \| "warn" \| None` | `None` | Whether PyTorch should use deterministic backend | -| `batch_size` | `int` | `32` | Batch size used for training | -| `accumulate_grad_batches` | `int` | `1` | Number of batches for gradient accumulation | -| `gradient_clip_val` | `NonNegativeFloat \| None` | `None` | Value for gradient clipping. If `None`, gradient clipping is disabled. Clipping can help prevent exploding gradients. | -| `gradient_clip_algorithm` | `Literal["norm", "value"] \| None` | `None` | Algorithm to use for gradient clipping. Options are `"norm"` (clip by norm) or `"value"` (clip element-wise). | -| `use_weighted_sampler` | `bool` | `False` | Whether to use `WeightedRandomSampler` for training, only works with classification tasks | -| `epochs` | `int` | `100` | Number of training epochs | -| `n_workers` | `int` | `4` | Number of workers for data loading | -| `validation_interval` | `int` | `5` | Frequency of computing metrics on validation data | -| `n_log_images` | `int` | `4` | Maximum number of images to visualize and log | -| `skip_last_batch` | `bool` | `True` | Whether to skip last batch while training | -| `accelerator` | `Literal["auto", "cpu", "gpu"]` | `"auto"` | What accelerator to use for training | -| `devices` | `int \| list[int] \| str` | `"auto"` | Either specify how many devices to use (int), list specific devices, or use "auto" for automatic configuration based on the selected accelerator | -| `matmul_precision` | `Literal["medium", "high", "highest"] \| None` | `None` | Sets the internal precision of float32 matrix multiplications | -| `strategy` | `Literal["auto", "ddp"]` | `"auto"` | What strategy to use for training | -| `n_sanity_val_steps` | `int` | `2` | Number of sanity validation steps performed before training | -| `profiler` | `Literal["simple", "advanced"] \| None` | `None` | PL profiler for GPU/CPU/RAM utilization analysis | -| `verbose` | `bool` | `True` | Print all intermediate results to console | -| `pin_memory` | `bool` | `True` | Whether to pin memory in the `DataLoader` | -| `save_top_k` | `-1 \| NonNegativeInt` | `3` | Save top K checkpoints based on validation loss when training | -| `n_validation_batches` | `PositiveInt \| None` | `None` | Limits the number of validation/test batches and makes the val/test loaders deterministic | -| `smart_cfg_auto_populate` | `bool` | `True` | Automatically populate sensible default values for missing config fields and log warnings | +| Key | Type | Default value | Description | +| ------------------------- | ---------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `seed` | `int` | `None` | Seed for reproducibility | +| `deterministic` | `bool \| "warn" \| None` | `None` | Whether PyTorch should use deterministic backend | +| `batch_size` | `int` | `32` | Batch size used for training | +| `accumulate_grad_batches` | `int` | `1` | Number of batches for gradient accumulation | +| `gradient_clip_val` | `NonNegativeFloat \| None` | `None` | Value for gradient clipping. If `None`, gradient clipping is disabled. Clipping can help prevent exploding gradients. | +| `gradient_clip_algorithm` | `Literal["norm", "value"] \| None` | `None` | Algorithm to use for gradient clipping. Options are `"norm"` (clip by norm) or `"value"` (clip element-wise). | +| `use_weighted_sampler` | `bool` | `False` | Whether to use `WeightedRandomSampler` for training, only works with classification tasks | +| `epochs` | `int` | `100` | Number of training epochs | +| `n_workers` | `int` | `4` | Number of workers for data loading | +| `validation_interval` | `int` | `5` | Frequency of computing metrics on validation data | +| `n_log_images` | `int` | `4` | Maximum number of images to visualize and log | +| `skip_last_batch` | `bool` | `True` | Whether to skip last batch while training | +| `accelerator` | `Literal["auto", "cpu", "gpu"]` | `"auto"` | What accelerator to use for training | +| `devices` | `int \| list[int] \| str` | `"auto"` | Either specify how many devices to use (int), list specific devices, or use "auto" for automatic configuration based on the selected accelerator | +| `matmul_precision` | `Literal["medium", "high", "highest"] \| None` | `None` | Sets the internal precision of float32 matrix multiplications | +| `strategy` | `Literal["auto", "ddp"]` | `"auto"` | What strategy to use for training | +| `n_sanity_val_steps` | `int` | `2` | Number of sanity validation steps performed before training | +| `profiler` | `Literal["simple", "advanced"] \| None` | `None` | PL profiler for GPU/CPU/RAM utilization analysis | +| `verbose` | `bool` | `True` | Print all intermediate results to console | +| `pin_memory` | `bool` | `True` | Whether to pin memory in the `DataLoader` | +| `save_top_k` | `-1 \| NonNegativeInt` | `3` | Save top K checkpoints based on validation loss when training | +| `n_validation_batches` | `PositiveInt \| None` | `None` | Limits the number of validation/test batches and makes the val/test loaders deterministic | +| `smart_cfg_auto_populate` | `bool` | `True` | Automatically populate sensible default values for missing config fields and log warnings | +| `resume_training` | `bool` | `False` | Whether to resume training from a checkpoint. Loads not only the weights from `model.weights` but also the `optimizer state`, `scheduler state`, and other training parameters to continue seamlessly. | ```yaml @@ -234,7 +235,7 @@ trainer: accelerator: "auto" devices: "auto" strategy: "auto" - + resume_training: true n_sanity_val_steps: 1 profiler: null verbose: true @@ -250,6 +251,21 @@ trainer: smart_cfg_auto_populate: true ``` +### Model Fine-Tuning Options + +#### 1. **Fine-Tuning with Custom Configuration Example** + +- Do **not** set the `resume_training` flag to `true`. +- Specify a **new LR** in the config (e.g., `0.1`), overriding the previous LR (e.g., `0.001`). +- Training starts at the new LR, with the scheduler/optimizer reset (can use different schedulers/optimizers than the base run). + +#### 2. **Resume Training Continuously Example** + +- Use the `resume_training` flag to continue training from the last checkpoint specified in `model.weights`. +- LR starts at the value where the previous run ended, maintaining continuity in the scheduler (e.g., combined LR plot would shows a seamless curve). +- For example: + - Resuming training with extended epochs (e.g., 400 epochs after 300) and adjusted `T_max` (e.g., 400 after 300 for cosine annealing) and `eta_min` (e.g., 10x less than before) will use the final learning rate (LR) from the previous run. This ignores the initial LR specified in the config and finishes with the new `eta_min` LR. + ### Smart Configuration Auto-population When setting `trainer.smart_cfg_auto_populate = True`, the following set of rules will be applied: diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index c0aae2dc..a260786c 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -70,7 +70,7 @@ class _ViewType(str, Enum): @app.command() def train( config: ConfigType = None, - resume: Annotated[ + resume_checkpoint: Annotated[ str | None, typer.Option(help="Resume training from this checkpoint."), ] = None, @@ -79,7 +79,7 @@ def train( """Start training.""" from luxonis_train.core import LuxonisModel - LuxonisModel(config, opts).train(resume_weights=resume) + LuxonisModel(config, opts).train(resume_checkpoint=resume_checkpoint) @app.command() diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index fcbbf24a..e313642c 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -380,6 +380,7 @@ class TrainerConfig(BaseModelExtraForbid): gradient_clip_algorithm: Literal["norm", "value"] | None = None use_weighted_sampler: bool = False epochs: PositiveInt = 100 + resume_training: bool = False n_workers: Annotated[ NonNegativeInt, Field(validation_alias=AliasChoices("n_workers", "num_workers")), diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 03ff5189..c92d2644 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -240,15 +240,15 @@ def _train(self, resume: str | None, *args, **kwargs): self.tracker._finalize(status) def train( - self, new_thread: bool = False, resume_weights: str | None = None + self, new_thread: bool = False, resume_checkpoint: str | None = None ) -> None: """Runs training. @type new_thread: bool @param new_thread: Runs training in new thread if set to True. - @type resume_weights: str | None - @param resume_weights: Path to the checkpoint from which to to - resume the training. + @type resume_checkpoint: str | None + @param resume_checkpoint: Path to the checkpoint from which to + to resume the training. """ if self.cfg.trainer.matmul_precision is not None: @@ -259,9 +259,11 @@ def train( self.cfg.trainer.matmul_precision ) - if resume_weights is not None: - resume_weights = str( - LuxonisFileSystem.download(resume_weights, self.run_save_dir) + if resume_checkpoint is not None: + resume_checkpoint = str( + LuxonisFileSystem.download( + resume_checkpoint, self.run_save_dir + ) ) def graceful_exit(signum: int, _): # pragma: no cover @@ -280,9 +282,15 @@ def graceful_exit(signum: int, _): # pragma: no cover if not new_thread: logger.info(f"Checkpoints will be saved in: {self.run_save_dir}") + if self.cfg.trainer.resume_training: + if resume_checkpoint is not None: + logger.warning( + "Resume weights provided in the command line, but resume_training in config is set to True. Ignoring resume weights provided in the command line." + ) + resume_checkpoint = self.cfg.model.weights logger.info("Starting training...") self._train( - resume_weights, + resume_checkpoint, self.lightning_module, self.pytorch_loaders["train"], self.pytorch_loaders["val"], @@ -300,7 +308,7 @@ def thread_exception_hook(args): self.thread = threading.Thread( target=self._train, args=( - resume_weights, + resume_checkpoint, self.lightning_module, self.pytorch_loaders["train"], self.pytorch_loaders["val"], From b152a6a4bd7c2ed63eb12f9a294e5f0141879295 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 23 Jan 2025 15:27:02 +0100 Subject: [PATCH 4/6] fix failing type-check --- luxonis_train/core/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index c92d2644..c4572a5a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -287,7 +287,7 @@ def graceful_exit(signum: int, _): # pragma: no cover logger.warning( "Resume weights provided in the command line, but resume_training in config is set to True. Ignoring resume weights provided in the command line." ) - resume_checkpoint = self.cfg.model.weights + resume_checkpoint = self.cfg.model.weights # type: ignore logger.info("Starting training...") self._train( resume_checkpoint, From 7fa34509dc997f82b0ee3610a579f9f63cf9b199 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 23 Jan 2025 19:10:03 +0100 Subject: [PATCH 5/6] rename to resume_weights --- luxonis_train/core/core.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index c4572a5a..b272ad15 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -240,15 +240,15 @@ def _train(self, resume: str | None, *args, **kwargs): self.tracker._finalize(status) def train( - self, new_thread: bool = False, resume_checkpoint: str | None = None + self, new_thread: bool = False, resume_weights: str | None = None ) -> None: """Runs training. @type new_thread: bool @param new_thread: Runs training in new thread if set to True. - @type resume_checkpoint: str | None - @param resume_checkpoint: Path to the checkpoint from which to - to resume the training. + @type resume_weights: str | None + @param resume_weights: Path to the checkpoint from which to to + resume the training. """ if self.cfg.trainer.matmul_precision is not None: @@ -259,11 +259,9 @@ def train( self.cfg.trainer.matmul_precision ) - if resume_checkpoint is not None: - resume_checkpoint = str( - LuxonisFileSystem.download( - resume_checkpoint, self.run_save_dir - ) + if resume_weights is not None: + resume_weights = str( + LuxonisFileSystem.download(resume_weights, self.run_save_dir) ) def graceful_exit(signum: int, _): # pragma: no cover @@ -283,14 +281,14 @@ def graceful_exit(signum: int, _): # pragma: no cover if not new_thread: logger.info(f"Checkpoints will be saved in: {self.run_save_dir}") if self.cfg.trainer.resume_training: - if resume_checkpoint is not None: + if resume_weights is not None: logger.warning( "Resume weights provided in the command line, but resume_training in config is set to True. Ignoring resume weights provided in the command line." ) - resume_checkpoint = self.cfg.model.weights # type: ignore + resume_weights = self.cfg.model.weights # type: ignore logger.info("Starting training...") self._train( - resume_checkpoint, + resume_weights, self.lightning_module, self.pytorch_loaders["train"], self.pytorch_loaders["val"], @@ -308,7 +306,7 @@ def thread_exception_hook(args): self.thread = threading.Thread( target=self._train, args=( - resume_checkpoint, + resume_weights, self.lightning_module, self.pytorch_loaders["train"], self.pytorch_loaders["val"], From fd09c8c566d8ff3b840025aabbc3515a77e248d2 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 23 Jan 2025 19:18:57 +0100 Subject: [PATCH 6/6] rename to resume_weights --- luxonis_train/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index a260786c..c6cd199c 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -70,7 +70,7 @@ class _ViewType(str, Enum): @app.command() def train( config: ConfigType = None, - resume_checkpoint: Annotated[ + resume_weights: Annotated[ str | None, typer.Option(help="Resume training from this checkpoint."), ] = None, @@ -79,7 +79,7 @@ def train( """Start training.""" from luxonis_train.core import LuxonisModel - LuxonisModel(config, opts).train(resume_checkpoint=resume_checkpoint) + LuxonisModel(config, opts).train(resume_weights=resume_weights) @app.command()