diff --git a/README.md b/README.md
index 6d718afa..f6f82ed2 100644
--- a/README.md
+++ b/README.md
@@ -567,6 +567,7 @@ model.tune()
 - [**Callbacks**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md): Allow custom code to be executed at different stages of training.
 - [**Optimizers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer): Control how the model's weights are updated.
 - [**Schedulers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler): Adjust the learning rate during training.
+- [**Training Strategy**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#training-strategy): Specify a custom combination of optimizer and scheduler to tailor the training process for specific use cases.
 
 **Creating Custom Components:**
 
@@ -581,6 +582,7 @@ Registered components can be referenced in the config file. Custom components ne
 - **Callbacks** - [`lightning.pytorch.callbacks.Callback`](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), requires manual registration to the `CALLBACKS` registry
 - **Optimizers** - [`torch.optim.Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), requires manual registration to the `OPTIMIZERS` registry
 - **Schedulers** - [`torch.optim.lr_scheduler.LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate), requires manual registration to the `SCHEDULERS` registry
+- **Training Strategy** - [`BaseTrainingStrategy`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/strategies/base_strategy.py)
 
 **Examples:**
 
diff --git a/configs/README.md b/configs/README.md
index 8a9e1c01..e8281237 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -376,6 +376,37 @@ trainer:
       eta_min: 0
 ```
 
+### Training Strategy
+
+Defines the training strategy to be used. Currently, only the `TripleLRSGDStrategy` is supported, but more strategies will be added in the future.
+
+| Key               | Type    | Default value           | Description                                    |
+| ----------------- | ------- | ----------------------- | ---------------------------------------------- |
+| `name`            | `str`   | `"TripleLRSGDStrategy"` | Name of the training strategy                  |
+| `warmup_epochs`   | `int`   | `3`                     | Number of epochs for the warmup phase          |
+| `warmup_bias_lr`  | `float` | `0.1`                   | Learning rate for bias during the warmup phase |
+| `warmup_momentum` | `float` | `0.8`                   | Momentum value during the warmup phase         |
+| `lr`              | `float` | `0.02`                  | Initial learning rate                          |
+| `lre`             | `float` | `0.0002`                | End learning rate                              |
+| `momentum`        | `float` | `0.937`                 | Momentum for the optimizer                     |
+| `weight_decay`    | `float` | `0.0005`                | Weight decay value                             |
+| `nesterov`        | `bool`  | `true`                  | Whether to use Nesterov momentum               |
+
+**Example:**
+
+```yaml
+training_strategy:
+  name: "TripleLRSGDStrategy"
+  warmup_epochs: 3
+  warmup_bias_lr: 0.1
+  warmup_momentum: 0.8
+  lr: 0.02
+  lre: 0.0002
+  momentum: 0.937
+  weight_decay: 0.0005
+  nesterov: true
+```
+
 ## Exporter
 
 Here you can define configuration for exporting.
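
For reference, a custom training strategy registered against the new `STRATEGIES` registry might look like the following minimal sketch. The class name and its parameters are hypothetical; only `BaseTrainingStrategy`, its abstract methods, and the lookup by `name` in the config come from this changeset, and auto-registration of subclasses on import is an assumption about the registry behaviour.

```python
import pytorch_lightning as pl
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR

from luxonis_train.strategies import BaseTrainingStrategy


class MyConstantSGDStrategy(BaseTrainingStrategy):
    """Hypothetical strategy: plain SGD with a constant LR schedule."""

    def __init__(self, pl_module: pl.LightningModule, params: dict):
        super().__init__(pl_module)
        self.params = params or {}

    def configure_optimizers(self):
        # Called by the lightning module's configure_optimizers when
        # a training strategy is set in the config.
        optimizer = SGD(
            self.pl_module.parameters(), lr=self.params.get("lr", 0.01)
        )
        scheduler = ConstantLR(optimizer)
        return [optimizer], [scheduler]

    def update_parameters(self, *args, **kwargs):
        # Invoked by the TrainingManager callback after every backward pass;
        # a no-op here, but per-step LR or momentum updates would go here.
        pass
```

Such a class would then be selected by its name in the training strategy section, in the same way as `TripleLRSGDStrategy` in the example above.
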
diff --git a/luxonis_train/__init__.py b/luxonis_train/__init__.py
index ac6e38a1..e9651769 100644
--- a/luxonis_train/__init__.py
+++ b/luxonis_train/__init__.py
@@ -10,6 +10,7 @@
     from .nodes import *
     from .optimizers import *
     from .schedulers import *
+    from .strategies import *
     from .utils import *
 except ImportError as e:
     warnings.warn(
diff --git a/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py b/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
index 8c7252ee..f9f4150e 100644
--- a/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
@@ -57,6 +57,12 @@ def draw_predictions(
         prediction = predictions[i]
         mask = prediction[..., 2] < visibility_threshold
         visible_kpts = prediction[..., :2] * (~mask).unsqueeze(-1).float()
+        visible_kpts[..., 0] = torch.clamp(
+            visible_kpts[..., 0], 0, canvas.size(-1) - 1
+        )
+        visible_kpts[..., 1] = torch.clamp(
+            visible_kpts[..., 1], 0, canvas.size(-2) - 1
+        )
         viz[i] = draw_keypoints(
             canvas[i].clone(),
             visible_kpts[..., :2],
diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py
index a3cf907c..7bea71a9 100644
--- a/luxonis_train/callbacks/__init__.py
+++ b/luxonis_train/callbacks/__init__.py
@@ -25,6 +25,7 @@
 from .metadata_logger import MetadataLogger
 from .module_freezer import ModuleFreezer
 from .test_on_train_end import TestOnTrainEnd
+from .training_manager import TrainingManager
 from .upload_checkpoint import UploadCheckpoint
 
 CALLBACKS.register_module(module=EarlyStopping)
@@ -38,6 +39,7 @@
 CALLBACKS.register_module(module=ModelPruning)
 CALLBACKS.register_module(module=GradCamCallback)
 CALLBACKS.register_module(module=EMACallback)
+CALLBACKS.register_module(module=TrainingManager)
 
 
 __all__ = [
@@ -53,4 +55,5 @@
     "GPUStatsMonitor",
     "GradCamCallback",
     "EMACallback",
+    "TrainingManager",
 ]
diff --git a/luxonis_train/callbacks/training_manager.py b/luxonis_train/callbacks/training_manager.py
new file mode 100644
index 00000000..9131fa84
--- /dev/null
+++ b/luxonis_train/callbacks/training_manager.py
@@ -0,0 +1,28 @@
+import pytorch_lightning as pl
+
+from luxonis_train.strategies.base_strategy import BaseTrainingStrategy
+
+
+class TrainingManager(pl.Callback):
+    def __init__(self, strategy: BaseTrainingStrategy | None = None):
+        """Training manager callback that updates the parameters of the
+        training strategy.
+
+        @type strategy: BaseTrainingStrategy
+        @param strategy: The strategy to be used.
+        """
+        self.strategy = strategy
+
+    def on_after_backward(
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule
+    ):
+        """PyTorch Lightning hook that is called after the backward
+        pass.
+
+        @type trainer: pl.Trainer
+        @param trainer: The trainer object.
+        @type pl_module: pl.LightningModule
+        @param pl_module: The LightningModule being trained.
+ """ + if self.strategy is not None: + self.strategy.update_parameters(pl_module) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 67fdc8f0..941cd649 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -327,12 +327,17 @@ class CallbackConfig(BaseModelExtraForbid): class OptimizerConfig(BaseModelExtraForbid): - name: str = "Adam" + name: str params: Params = {} class SchedulerConfig(BaseModelExtraForbid): - name: str = "ConstantLR" + name: str + params: Params = {} + + +class TrainingStrategyConfig(BaseModelExtraForbid): + name: str params: Params = {} @@ -380,8 +385,9 @@ class TrainerConfig(BaseModelExtraForbid): callbacks: list[CallbackConfig] = [] - optimizer: OptimizerConfig = OptimizerConfig() - scheduler: SchedulerConfig = SchedulerConfig() + optimizer: OptimizerConfig | None = None + scheduler: SchedulerConfig | None = None + training_strategy: TrainingStrategyConfig | None = None @model_validator(mode="after") def validate_deterministic(self) -> Self: @@ -511,6 +517,7 @@ def get_config( ) -> "Config": instance = super().get_config(cfg, overrides) if not isinstance(cfg, str): + cls.smart_auto_populate(instance) return instance fs = LuxonisFileSystem(cfg) if fs.is_mlflow: @@ -530,16 +537,34 @@ def smart_auto_populate(cls, instance: "Config") -> None: """Automatically populates config fields based on rules, with warnings.""" + # Rule: Set default optimizer and scheduler if training_strategy is not defined and optimizer and scheduler are None + if instance.trainer.training_strategy is None: + if instance.trainer.optimizer is None: + instance.trainer.optimizer = OptimizerConfig( + name="Adam", params={} + ) + logger.warning( + "Optimizer not specified. Automatically set to `Adam`." + ) + if instance.trainer.scheduler is None: + instance.trainer.scheduler = SchedulerConfig( + name="ConstantLR", params={} + ) + logger.warning( + "Scheduler not specified. Automatically set to `ConstantLR`." + ) + # Rule: CosineAnnealingLR should have T_max set to the number of epochs if not provided - scheduler = instance.trainer.scheduler - if ( - scheduler.name == "CosineAnnealingLR" - and "T_max" not in scheduler.params - ): - scheduler.params["T_max"] = instance.trainer.epochs - logger.warning( - "`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs." - ) + if instance.trainer.scheduler is not None: + scheduler = instance.trainer.scheduler + if ( + scheduler.name == "CosineAnnealingLR" + and "T_max" not in scheduler.params + ): + scheduler.params["T_max"] = instance.trainer.epochs + logger.warning( + "`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs." 
+                )
 
         # Rule: Mosaic4 should have out_width and out_height matching train_image_size if not provided
         for augmentation in instance.trainer.preprocessing.augmentations:
diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py
index 011c3983..08d0066f 100644
--- a/luxonis_train/models/luxonis_lightning.py
+++ b/luxonis_train/models/luxonis_lightning.py
@@ -25,7 +25,11 @@
     combine_visualizations,
     get_denormalized_images,
 )
-from luxonis_train.callbacks import BaseLuxonisProgressBar, ModuleFreezer
+from luxonis_train.callbacks import (
+    BaseLuxonisProgressBar,
+    ModuleFreezer,
+    TrainingManager,
+)
 from luxonis_train.config import AttachedModuleConfig, Config
 from luxonis_train.nodes import BaseNode
 from luxonis_train.utils import (
@@ -42,6 +46,7 @@
     CALLBACKS,
     OPTIMIZERS,
     SCHEDULERS,
+    STRATEGIES,
     Registry,
 )
 
@@ -268,6 +273,24 @@ def __init__(
 
             self.load_checkpoint(self.cfg.model.weights)
 
+        if self.cfg.trainer.training_strategy is not None:
+            if (
+                self.cfg.trainer.optimizer is not None
+                or self.cfg.trainer.scheduler is not None
+            ):
+                raise ValueError(
+                    "Training strategy is defined, but optimizer or scheduler is also defined. "
+                    "Please remove optimizer and scheduler from the config."
+                )
+            self.training_strategy = STRATEGIES.get(
+                self.cfg.trainer.training_strategy.name
+            )(
+                pl_module=self,
+                params=self.cfg.trainer.training_strategy.params,  # type: ignore
+            )
+        else:
+            self.training_strategy = None
+
     @property
     def core(self) -> "luxonis_train.core.LuxonisModel":
         """Returns the core model."""
@@ -849,6 +872,9 @@ def configure_callbacks(self) -> list[pl.Callback]:
                 CALLBACKS.get(callback.name)(**callback.params)
             )
 
+        if self.training_strategy is not None:
+            callbacks.append(TrainingManager(strategy=self.training_strategy))  # type: ignore
+
         return callbacks
 
     def configure_optimizers(
@@ -858,9 +884,17 @@ def configure_optimizers(
         list[torch.optim.lr_scheduler.LRScheduler],
     ]:
         """Configures model optimizers and schedulers."""
+        if self.training_strategy is not None:
+            return self.training_strategy.configure_optimizers()
+
         cfg_optimizer = self.cfg.trainer.optimizer
         cfg_scheduler = self.cfg.trainer.scheduler
 
+        if cfg_optimizer is None or cfg_scheduler is None:
+            raise ValueError(
+                "Optimizer and scheduler configuration must not be None."
+            )
+
         optim_params = cfg_optimizer.params | {
             "params": filter(lambda p: p.requires_grad, self.parameters()),
         }
diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md
index cef35029..7c7b53c4 100644
--- a/luxonis_train/nodes/README.md
+++ b/luxonis_train/nodes/README.md
@@ -82,16 +82,17 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
 
 **Parameters:**
 
-| Key                | Type                                                              | Default value               | Description                                                                  |
-| ------------------ | ----------------------------------------------------------------- | --------------------------- | ---------------------------------------------------------------------------- |
-| `variant`          | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"`                    | Variant of the network                                                       |
-| `channels_list`    | `list[int]`                                                       | \[64, 128, 256, 512, 1024\] | List of number of channels for each block                                    |
-| `n_repeats`        | `list[int]`                                                       | \[1, 6, 12, 18, 6\]         | List of number of repeats of `RepVGGBlock`                                   |
-| `depth_mul`        | `float`                                                           | `0.33`                      | Depth multiplier                                                             |
-| `width_mul`        | `float`                                                           | `0.25`                      | Width multiplier                                                             |
-| `block`            | `Literal["RepBlock", "CSPStackRepBlock"]`                         | `"RepBlock"`                | Base block used                                                              |
-| `csp_e`            | `float`                                                           | `0.5`                       | Factor for intermediate channels when block is set to `"CSPStackRepBlock"`   |
-| `download_weights` | `bool`                                                            | `True`                      | If True download weights from COCO (if available for specified variant)     |
+| Key                  | Type                                                              | Default value               | Description                                                                  |
+| -------------------- | ----------------------------------------------------------------- | --------------------------- | ---------------------------------------------------------------------------- |
+| `variant`            | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"`                    | Variant of the network                                                       |
+| `channels_list`      | `list[int]`                                                       | \[64, 128, 256, 512, 1024\] | List of number of channels for each block                                    |
+| `n_repeats`          | `list[int]`                                                       | \[1, 6, 12, 18, 6\]         | List of number of repeats of `RepVGGBlock`                                   |
+| `depth_mul`          | `float`                                                           | `0.33`                      | Depth multiplier                                                             |
+| `width_mul`          | `float`                                                           | `0.25`                      | Width multiplier                                                             |
+| `block`              | `Literal["RepBlock", "CSPStackRepBlock"]`                         | `"RepBlock"`                | Base block used                                                              |
+| `csp_e`              | `float`                                                           | `0.5`                       | Factor for intermediate channels when block is set to `"CSPStackRepBlock"`   |
+| `download_weights`   | `bool`                                                            | `True`                      | If True download weights from COCO (if available for specified variant)     |
+| `initialize_weights` | `bool`                                                            | `True`                      | If True, initialize weights.                                                 |
 
 ### RexNetV1_lite
 
@@ -175,17 +176,18 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
 
 **Parameters:**
 
-| Key                | Type                                                              | Default value                    | Description                                                                       |
-| ------------------ | ----------------------------------------------------------------- | -------------------------------- | --------------------------------------------------------------------------------- |
-| `variant`          | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"`                         | Variant of the network                                                            |
-| `n_heads`          | `Literal[2,3,4]`                                                  | `3`                              | Number of output heads. Should be same also on the connected head in most cases  |
-| `channels_list`    | `list[int]`                                                       | `[256, 128, 128, 256, 256, 512]` | List of number of channels for each block                                        |
-| `n_repeats`        | `list[int]`                                                       | `[12, 12, 12, 12]`               | List of number of repeats of `RepVGGBlock`                                       |
-| `depth_mul`        | `float`                                                           | `0.33`                           | Depth multiplier                                                                  |
-| `width_mul`        | `float`                                                           | `0.25`                           | Width multiplier                                                                  |
-| `block`            | `Literal["RepBlock", "CSPStackRepBlock"]`                         | `"RepBlock"`                     | Base block used                                                                   |
-| `csp_e`            | `float`                                                           | `0.5`                            | Factor for intermediate channels when block is set to `"CSPStackRepBlock"`       |
-| `download_weights` | `bool`                                                            | `False`                          | If True download weights from COCO (if available for specified variant)          |
+| Key                  | Type                                                              | Default value                    | Description                                                                       |
+| -------------------- | ----------------------------------------------------------------- | -------------------------------- | --------------------------------------------------------------------------------- |
+| `variant`            | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"`                         | Variant of the network                                                            |
+| `n_heads`            | `Literal[2,3,4]`                                                  | `3`                              | Number of output heads. Should be same also on the connected head in most cases  |
+| `channels_list`      | `list[int]`                                                       | `[256, 128, 128, 256, 256, 512]` | List of number of channels for each block                                        |
+| `n_repeats`          | `list[int]`                                                       | `[12, 12, 12, 12]`               | List of number of repeats of `RepVGGBlock`                                       |
+| `depth_mul`          | `float`                                                           | `0.33`                           | Depth multiplier                                                                  |
+| `width_mul`          | `float`                                                           | `0.25`                           | Width multiplier                                                                  |
+| `block`              | `Literal["RepBlock", "CSPStackRepBlock"]`                         | `"RepBlock"`                     | Base block used                                                                   |
+| `csp_e`              | `float`                                                           | `0.5`                            | Factor for intermediate channels when block is set to `"CSPStackRepBlock"`       |
+| `download_weights`   | `bool`                                                            | `False`                          | If True download weights from COCO (if available for specified variant)          |
+| `initialize_weights` | `bool`                                                            | `True`                           | If True, initialize weights.                                                     |
 
 ## Heads
 
@@ -217,13 +219,14 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
 
 **Parameters:**
 
-| Key                | Type    | Default value | Description                                                            |
-| ------------------ | ------- | ------------- | ---------------------------------------------------------------------- |
-| `n_heads`          | `bool`  | `3`           | Number of output heads                                                 |
-| `conf_thres`       | `float` | `0.25`        | Confidence threshold for non-maxima-suppression (used for evaluation) |
-| `iou_thres`        | `float` | `0.45`        | `IoU` threshold for non-maxima-suppression (used for evaluation)      |
-| `max_det`          | `int`   | `300`         | Maximum number of detections retained after NMS                       |
-| `download_weights` | `bool`  | `False`       | If True download weights from COCO                                     |
+| Key                  | Type    | Default value | Description                                                            |
+| -------------------- | ------- | ------------- | ---------------------------------------------------------------------- |
+| `n_heads`            | `int`   | `3`           | Number of output heads                                                 |
+| `conf_thres`         | `float` | `0.25`        | Confidence threshold for non-maxima-suppression (used for evaluation) |
+| `iou_thres`          | `float` | `0.45`        | `IoU` threshold for non-maxima-suppression (used for evaluation)      |
+| `max_det`            | `int`   | `300`         | Maximum number of detections retained after NMS                       |
+| `download_weights`   | `bool`  | `False`       | If True download weights from COCO                                     |
+| `initialize_weights` | `bool`  | `True`        | If True, initialize weights.                                           |
 
 ### `EfficientKeypointBBoxHead`
 
diff --git a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py
index d094da14..340cc10a 100644
--- a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py
+++ b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -30,6 +30,7 @@ def __init__(
         block: Literal["RepBlock", "CSPStackRepBlock"] | None = None,
         csp_e: float | None = None,
         download_weights: bool = True,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Implementation of the EfficientRep backbone. Supports the
@@ -65,6 +66,8 @@ def __init__(
             overrides the variant value.
         @type download_weights: bool
         @param download_weights: If True download weights from COCO (if
            available for specified variant). Defaults to True.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights of the model.
         """
         super().__init__(**kwargs)
@@ -125,9 +128,24 @@ def __init__(
                 )
             )
 
+        if initialize_weights:
+            self.initialize_weights()
+
         if download_weights and var.weights_path:
             self.load_checkpoint(var.weights_path)
 
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                pass
+            elif isinstance(m, nn.BatchNorm2d):
+                m.eps = 0.001
+                m.momentum = 0.03
+            elif isinstance(
+                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
+            ):
+                m.inplace = True
+
     def set_export_mode(self, mode: bool = True) -> None:
         """Reparametrizes instances of L{RepVGGBlock} in the network.
 
diff --git a/luxonis_train/nodes/heads/efficient_bbox_head.py b/luxonis_train/nodes/heads/efficient_bbox_head.py
index c8500915..531a294f 100644
--- a/luxonis_train/nodes/heads/efficient_bbox_head.py
+++ b/luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -30,6 +30,7 @@ def __init__(
         iou_thres: float = 0.45,
         max_det: int = 300,
         download_weights: bool = False,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Head for object detection.
@@ -51,6 +52,8 @@ def __init__(
         @type download_weights: bool
         @param download_weights: If True download weights from COCO.
             Defaults to False.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights.
         """
 
         super().__init__(**kwargs)
@@ -95,12 +98,27 @@ def __init__(
             f"output{i+1}_yolov6r2" for i in range(self.n_heads)
         ]
 
+        if initialize_weights:
+            self.initialize_weights()
+
         if download_weights:
             # TODO: Handle variants of head in a nicer way
             if self.in_channels == [32, 64, 128]:
                 weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/efficientbbox_head_n_coco.ckpt"
                 self.load_checkpoint(weights_path, strict=False)
 
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                pass
+            elif isinstance(m, nn.BatchNorm2d):
+                m.eps = 0.001
+                m.momentum = 0.03
+            elif isinstance(
+                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
+            ):
+                m.inplace = True
+
     def forward(
         self, inputs: list[Tensor]
     ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
diff --git a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
index e6b321be..2c95890a 100644
--- a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
+++ b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
@@ -27,6 +27,7 @@ def __init__(
         block: Literal["RepBlock", "CSPStackRepBlock"] | None = None,
         csp_e: float | None = None,
         download_weights: bool = False,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Implementation of the RepPANNeck module. Supports the version
@@ -65,6 +66,8 @@
             overrides the variant value.
         @type download_weights: bool
         @param download_weights: If True download weights from COCO (if
            available for specified variant). Defaults to False.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights of the model.
         """
         super().__init__(**kwargs)
@@ -165,9 +168,24 @@ def __init__(
             out_channels = channels_list_down_blocks[2 * i + 1]
             curr_n_repeats = n_repeats_down_blocks[i]
 
+        if initialize_weights:
+            self.initialize_weights()
+
         if download_weights and var.weights_path:
             self.load_checkpoint(var.weights_path)
 
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                pass
+            elif isinstance(m, nn.BatchNorm2d):
+                m.eps = 0.001
+                m.momentum = 0.03
+            elif isinstance(
+                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
+            ):
+                m.inplace = True
+
     def forward(self, inputs: list[Tensor]) -> list[Tensor]:
         x = inputs[-1]
         up_block_outs: list[Tensor] = []
diff --git a/luxonis_train/strategies/__init__.py b/luxonis_train/strategies/__init__.py
new file mode 100644
index 00000000..b83d7190
--- /dev/null
+++ b/luxonis_train/strategies/__init__.py
@@ -0,0 +1,7 @@
+from .base_strategy import BaseTrainingStrategy
+from .triple_lr_sgd import TripleLRScheduler
+
+__all__ = [
+    "TripleLRScheduler",
+    "BaseTrainingStrategy",
+]
diff --git a/luxonis_train/strategies/base_strategy.py b/luxonis_train/strategies/base_strategy.py
new file mode 100644
index 00000000..8de6386d
--- /dev/null
+++ b/luxonis_train/strategies/base_strategy.py
@@ -0,0 +1,28 @@
+from abc import ABC, abstractmethod
+
+import pytorch_lightning as pl
+from luxonis_ml.utils.registry import AutoRegisterMeta
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+
+from luxonis_train.utils.registry import STRATEGIES
+
+
+class BaseTrainingStrategy(
+    ABC,
+    metaclass=AutoRegisterMeta,
+    register=False,
+    registry=STRATEGIES,
+):
+    def __init__(self, pl_module: pl.LightningModule):
+        self.pl_module = pl_module
+
+    @abstractmethod
+    def configure_optimizers(
+        self,
+    ) -> tuple[list[Optimizer], list[LRScheduler]]:
+        pass
+
+    @abstractmethod
+    def update_parameters(self, *args, **kwargs):
+        pass
diff --git a/luxonis_train/strategies/triple_lr_sgd.py b/luxonis_train/strategies/triple_lr_sgd.py
new file mode 100644
index 00000000..33f7dfe3
--- /dev/null
+++ b/luxonis_train/strategies/triple_lr_sgd.py
@@ -0,0 +1,171 @@
+# strategies/triple_lr_sgd.py
+import math
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch.optim import SGD
+from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.optimizer import Optimizer
+
+from .base_strategy import BaseTrainingStrategy
+
+
+class TripleLRScheduler:
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        params: dict,
+        epochs: int,
+        max_stepnum: int,
+    ) -> None:
+        """TripleLRScheduler scheduler.
+
+        @type optimizer: torch.optim.Optimizer
+        @param optimizer: The optimizer to be used.
+        @type params: dict
+        @param params: The parameters for the scheduler.
+        @type epochs: int
+        @param epochs: The number of epochs to train for.
+        @type max_stepnum: int
+        @param max_stepnum: The maximum number of steps to train for.
+ """ + self.optimizer = optimizer + self.params = { + "warmup_epochs": 3, + "warmup_bias_lr": 0.1, + "warmup_momentum": 0.8, + "lre": 0.0002, + } + if params: + self.params.update(params) + self.max_stepnum = max_stepnum + self.warmup_stepnum = max( + round(self.params["warmup_epochs"] * self.max_stepnum), 1000 + ) + self.step = 0 + self.lrf = self.params["lre"] / self.optimizer.defaults["lr"] + self.lf = ( + lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) + * (self.lrf - 1) + + 1 + ) + + def create_scheduler(self): + scheduler = LambdaLR(self.optimizer, lr_lambda=self.lf) + return scheduler + + def update_learning_rate(self, current_epoch: int) -> None: + self.step = self.step % self.max_stepnum + curr_step = self.step + self.max_stepnum * current_epoch + + if curr_step <= self.warmup_stepnum: + for k, param in enumerate(self.optimizer.param_groups): + warmup_bias_lr = ( + self.params["warmup_bias_lr"] if k == 2 else 0.0 + ) + param["lr"] = np.interp( + curr_step, + [0, self.warmup_stepnum], + [ + warmup_bias_lr, + self.optimizer.defaults["lr"] * self.lf(current_epoch), + ], + ) + if "momentum" in param: + self.optimizer.defaults["momentum"] = np.interp( + curr_step, + [0, self.warmup_stepnum], + [ + self.params["warmup_momentum"], + self.optimizer.defaults["momentum"], + ], + ) + self.step += 1 + + +class TripleLRSGD: + def __init__(self, model: torch.nn.Module, params: dict) -> None: + """TripleLRSGD optimizer. + + @type model: torch.nn.Module + @param model: The model to be used. + @type params: dict + @param params: The parameters for the optimizer. + """ + self.model = model + self.params = { + "lr": 0.02, + "momentum": 0.937, + "weight_decay": 0.0005, + "nesterov": True, + } + if params: + self.params.update(params) + + def create_optimizer(self): + batch_norm_weights, regular_weights, biases = [], [], [] + + for module in self.model.modules(): + if hasattr(module, "bias") and isinstance( + module.bias, torch.nn.Parameter + ): + biases.append(module.bias) + if isinstance(module, torch.nn.BatchNorm2d): + batch_norm_weights.append(module.weight) + elif hasattr(module, "weight") and isinstance( + module.weight, torch.nn.Parameter + ): + regular_weights.append(module.weight) + + optimizer = SGD( + [ + { + "params": batch_norm_weights, + "lr": self.params["lr"], + "momentum": self.params["momentum"], + "nesterov": self.params["nesterov"], + }, + { + "params": regular_weights, + "weight_decay": self.params["weight_decay"], + }, + {"params": biases}, + ], + lr=self.params["lr"], + momentum=self.params["momentum"], + nesterov=self.params["nesterov"], + ) + + return optimizer + + +class TripleLRSGDStrategy(BaseTrainingStrategy): + def __init__(self, pl_module: pl.LightningModule, params: dict): + """TripleLRSGD strategy. + + @type pl_module: pl.LightningModule + @param pl_module: The pl_module to be used. + @type params: dict + @param params: The parameters for the strategy. 
+ """ + super().__init__(pl_module) + self.model = pl_module + self.params = params + self.cfg = self.model.cfg + + max_stepnum = math.ceil( + len(self.model.core.loaders["train"]) / self.cfg.trainer.batch_size + ) + + self.optimizer = TripleLRSGD(self.model, params).create_optimizer() + self.scheduler = TripleLRScheduler( + self.optimizer, params, self.cfg.trainer.epochs, max_stepnum + ) + + def configure_optimizers(self) -> tuple[list[Optimizer], list[LambdaLR]]: + return [self.optimizer], [self.scheduler.create_scheduler()] + + def update_parameters(self, *args, **kwargs): + current_epoch = self.model.current_epoch + self.scheduler.update_learning_rate(current_epoch) diff --git a/luxonis_train/utils/registry.py b/luxonis_train/utils/registry.py index 8044f13c..4f413c7a 100644 --- a/luxonis_train/utils/registry.py +++ b/luxonis_train/utils/registry.py @@ -35,6 +35,11 @@ SCHEDULERS: Registry[type[LRScheduler]] = Registry(name="schedulers") """Registry for all schedulers.""" +STRATEGIES: Registry[type["lt.strategies.BaseTrainingStrategy"]] = Registry( + name="strategies" +) +"""Registry for all strategies.""" + VISUALIZERS: Registry[type["lt.visualizers.BaseVisualizer"]] = Registry( "visualizers" ) diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index b29e0420..e32980f2 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -69,14 +69,15 @@ def test_predefined_models( config_file = f"configs/{config_file}.yaml" opts |= { "loader.params.dataset_name": ( - cifar10_dataset.dataset_name + cifar10_dataset.identifier if "classification" in config_file - else coco_dataset.dataset_name + else coco_dataset.identifier ), + "trainer.epochs": 1, } model = LuxonisModel(config_file, opts) model.train() - model.test() + model.test(view="train") def test_multi_input(opts: dict[str, Any], infer_path: Path): @@ -280,7 +281,7 @@ def test_smart_cfg_auto_populate( } model = LuxonisModel(config_file, opts) assert ( - model.cfg.trainer.scheduler.params["T_max"] == model.cfg.trainer.epochs + model.cfg.trainer.scheduler.params["T_max"] == model.cfg.trainer.epochs # type: ignore ) assert ( model.cfg.trainer.preprocessing.augmentations[0].params["out_width"]