
fix ema checkpoint loading #156

Merged · 6 commits · Jan 25, 2025
68 changes: 42 additions & 26 deletions configs/README.md
@@ -202,39 +202,40 @@ loader:

Here you can change everything related to the actual training of the model.

| Key | Type | Default value | Description |
| ------------------------- | ---------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `seed` | `int` | `None` | Seed for reproducibility |
| `deterministic` | `bool \| "warn" \| None` | `None` | Whether PyTorch should use deterministic backend |
| `batch_size` | `int` | `32` | Batch size used for training |
| `accumulate_grad_batches` | `int` | `1` | Number of batches for gradient accumulation |
| `gradient_clip_val` | `NonNegativeFloat \| None` | `None` | Value for gradient clipping. If `None`, gradient clipping is disabled. Clipping can help prevent exploding gradients. |
| `gradient_clip_algorithm` | `Literal["norm", "value"] \| None` | `None` | Algorithm to use for gradient clipping. Options are `"norm"` (clip by norm) or `"value"` (clip element-wise). |
| `use_weighted_sampler` | `bool` | `False` | Whether to use `WeightedRandomSampler` for training, only works with classification tasks |
| `epochs` | `int` | `100` | Number of training epochs |
| `n_workers` | `int` | `4` | Number of workers for data loading |
| `validation_interval` | `int` | `5` | Frequency of computing metrics on validation data |
| `n_log_images` | `int` | `4` | Maximum number of images to visualize and log |
| `skip_last_batch` | `bool` | `True` | Whether to skip last batch while training |
| `accelerator` | `Literal["auto", "cpu", "gpu"]` | `"auto"` | What accelerator to use for training |
| `devices` | `int \| list[int] \| str` | `"auto"` | Either specify how many devices to use (int), list specific devices, or use "auto" for automatic configuration based on the selected accelerator |
| `matmul_precision` | `Literal["medium", "high", "highest"] \| None` | `None` | Sets the internal precision of float32 matrix multiplications |
| `strategy` | `Literal["auto", "ddp"]` | `"auto"` | What strategy to use for training |
| `n_sanity_val_steps` | `int` | `2` | Number of sanity validation steps performed before training |
| `profiler` | `Literal["simple", "advanced"] \| None` | `None` | PL profiler for GPU/CPU/RAM utilization analysis |
| `verbose` | `bool` | `True` | Print all intermediate results to console |
| `pin_memory` | `bool` | `True` | Whether to pin memory in the `DataLoader` |
| `save_top_k` | `-1 \| NonNegativeInt` | `3` | Save top K checkpoints based on validation loss when training |
| `n_validation_batches` | `PositiveInt \| None` | `None` | Limits the number of validation/test batches and makes the val/test loaders deterministic |
| `smart_cfg_auto_populate` | `bool` | `True` | Automatically populate sensible default values for missing config fields and log warnings |
| `resume_training` | `bool` | `False` | Whether to resume training from a checkpoint. Loads not only the weights from `model.weights` but also the `optimizer state`, `scheduler state`, and other training parameters to continue seamlessly. |

```yaml

trainer:
accelerator: "auto"
devices: "auto"
strategy: "auto"

resume_training: true
n_sanity_val_steps: 1
profiler: null
verbose: true
@@ -250,6 +251,21 @@
smart_cfg_auto_populate: true
```

### Model Fine-Tuning Options

#### 1. **Fine-Tuning with Custom Configuration Example**

- Do **not** set the `resume_training` flag to `true`.
- Specify a **new LR** in the config (e.g., `0.1`), overriding the previous LR (e.g., `0.001`).
- Training starts at the new LR, with the scheduler and optimizer reset (you can also use different schedulers or optimizers than in the base run); see the sketch below.
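A minimal illustrative config for this case (hypothetical checkpoint path and values; this sketch assumes the optimizer and scheduler sections sit under `trainer`, so adjust the keys to your actual config layout):

```yaml
model:
  weights: path/to/base_run.ckpt # checkpoint to fine-tune from

trainer:
  # `resume_training` is left at its default (false), so the optimizer and
  # scheduler state stored in the checkpoint are NOT restored.
  optimizer:
    name: SGD
    params:
      lr: 0.1 # new LR, overriding e.g. the 0.001 used in the base run
  scheduler:
    name: CosineAnnealingLR # a different scheduler than the base run is fine here
    params:
      T_max: 100
```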

#### 2. **Resume Training Continuously Example**

- Use the `resume_training` flag to continue training from the last checkpoint specified in `model.weights`.
- The LR starts at the value where the previous run ended, maintaining continuity in the scheduler (e.g., a combined LR plot shows a seamless curve).
- For example, resuming with extended epochs (e.g., 400 after an initial 300), an adjusted `T_max` (e.g., 400 instead of 300 for cosine annealing), and a smaller `eta_min` (e.g., 10x lower than before) continues from the final LR of the previous run; the initial LR specified in the config is ignored, and training finishes at the new `eta_min` LR. A sketch of such a config follows this list.
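A minimal illustrative config for this scenario, continuing the numbers above (hypothetical checkpoint path and values; again assuming the scheduler section sits under `trainer`):

```yaml
model:
  weights: path/to/epoch_300_run.ckpt # last checkpoint of the previous 300-epoch run

trainer:
  resume_training: true # also restores optimizer and scheduler state
  epochs: 400 # extended from the original 300
  scheduler:
    name: CosineAnnealingLR
    params:
      T_max: 400 # adjusted to the new epoch count
      eta_min: 0.00001 # e.g. 10x lower than in the previous run
```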

### Smart Configuration Auto-population

When setting `trainer.smart_cfg_auto_populate = True`, the following set of rules will be applied:
4 changes: 2 additions & 2 deletions luxonis_train/__main__.py
@@ -70,7 +70,7 @@ class _ViewType(str, Enum):
@app.command()
def train(
config: ConfigType = None,
resume: Annotated[
resume_weights: Annotated[
str | None,
typer.Option(help="Resume training from this checkpoint."),
] = None,
@@ -79,7 +79,7 @@ def train(
"""Start training."""
from luxonis_train.core import LuxonisModel

LuxonisModel(config, opts).train(resume_weights=resume)
LuxonisModel(config, opts).train(resume_weights=resume_weights)


@app.command()
35 changes: 17 additions & 18 deletions luxonis_train/callbacks/ema.py
@@ -25,7 +25,6 @@ def __init__(
decay: float = 0.9999,
use_dynamic_decay: bool = True,
decay_tau: float = 2000,
device: str | None = None,
):
"""Constructs `ModelEma`.

@@ -37,8 +36,6 @@
@param use_dynamic_decay: Use dynamic decay rate.
@type decay_tau: float
@param decay_tau: Decay tau for the moving average.
@type device: str | None
@param device: Device to perform EMA on.
"""
super(ModelEma, self).__init__()
model.eval()
@@ -51,11 +48,6 @@
self.decay = decay
self.use_dynamic_decay = use_dynamic_decay
self.decay_tau = decay_tau
self.device = device
if self.device is not None:
self.state_dict_ema = {
k: v.to(device=device) for k, v in self.state_dict_ema.items()
}

def update(self, model: pl.LightningModule) -> None:
"""Update the stored parameters using a moving average.
@@ -83,8 +75,6 @@ def update(self, model: pl.LightningModule) -> None:
self.state_dict_ema.values(), model.state_dict().values()
):
if ema_v.is_floating_point():
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_lerp_values.append(ema_v)
model_lerp_values.append(model_v)
else:
@@ -110,7 +100,6 @@
decay: float = 0.5,
use_dynamic_decay: bool = True,
decay_tau: float = 2000,
device: str | None = None,
):
"""Constructs `EMACallback`.

@@ -121,15 +110,13 @@
decay rate will be updated based on the number of updates.
@type decay_tau: float
@param decay_tau: Decay tau for the moving average.
@type device: str | None
@param device: Device to perform EMA on.
"""
self.decay = decay
self.use_dynamic_decay = use_dynamic_decay
self.decay_tau = decay_tau
self.device = device

self.ema = None
self.loaded_ema_state_dict = None
self.collected_state_dict = None

def on_fit_start(
@@ -149,8 +136,15 @@ def on_fit_start(
decay=self.decay,
use_dynamic_decay=self.use_dynamic_decay,
decay_tau=self.decay_tau,
device=self.device,
)
if self.loaded_ema_state_dict is not None:
target_device = next(iter(self.ema.state_dict_ema.values())).device
self.loaded_ema_state_dict = {
k: v.to(target_device)
for k, v in self.loaded_ema_state_dict.items()
}
self.ema.state_dict_ema = self.loaded_ema_state_dict
self.loaded_ema_state_dict = None

def on_train_batch_end(
self,
@@ -239,11 +233,16 @@ def on_save_checkpoint(
if self.ema is not None:
checkpoint["state_dict"] = self.ema.state_dict_ema

def on_load_checkpoint(self, callback_state: dict) -> None:
def on_load_checkpoint(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
callback_state: dict,
) -> None:
"""Load the EMA state_dict from the checkpoint.

@type callback_state: dict
@param callback_state: Pytorch Lightning callback state.
"""
if self.ema is not None:
self.ema.state_dict_ema = callback_state["state_dict"]
if callback_state and "state_dict" in callback_state:
self.loaded_ema_state_dict = callback_state["state_dict"]
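For context on the diff above: because `on_load_checkpoint` fires before the EMA copy exists, the checkpointed EMA state is only stashed there and applied later in `on_fit_start`, once the EMA weights have been created on the model's device. A simplified, self-contained sketch of that deferred-loading pattern (not the exact PR code; hook signatures follow recent PyTorch Lightning):

```python
import pytorch_lightning as pl
import torch


class DeferredEMALoading(pl.Callback):
    """Sketch of deferring EMA-state restoration until the EMA copy exists."""

    def __init__(self) -> None:
        self.ema_state: dict[str, torch.Tensor] | None = None  # live EMA weights
        self.loaded_ema_state: dict[str, torch.Tensor] | None = None  # stash from checkpoint

    def on_load_checkpoint(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict
    ) -> None:
        # Runs before `on_fit_start`, i.e. before the EMA copy is created,
        # so the state is only stashed here, not applied.
        if checkpoint and "state_dict" in checkpoint:
            self.loaded_ema_state = checkpoint["state_dict"]

    def on_fit_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        # The model now lives on its final device; build the EMA copy from it.
        self.ema_state = {
            k: v.detach().clone() for k, v in pl_module.state_dict().items()
        }
        if self.loaded_ema_state is not None:
            # Move the stashed tensors onto the same device as the fresh EMA
            # copy, replace it, and drop the stash.
            device = next(iter(self.ema_state.values())).device
            self.ema_state = {
                k: v.to(device) for k, v in self.loaded_ema_state.items()
            }
            self.loaded_ema_state = None
```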
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
@@ -380,6 +380,7 @@ class TrainerConfig(BaseModelExtraForbid):
gradient_clip_algorithm: Literal["norm", "value"] | None = None
use_weighted_sampler: bool = False
epochs: PositiveInt = 100
resume_training: bool = False
n_workers: Annotated[
NonNegativeInt,
Field(validation_alias=AliasChoices("n_workers", "num_workers")),