From 85261ec3df6c7456c837ae1df716eef6af906fea Mon Sep 17 00:00:00 2001
From: Hanyi Zhang
Date: Mon, 3 Jun 2024 15:52:16 +0200
Subject: [PATCH] Modify finetune function

---
 .../segmentation/cli/fine_tune_cli.py        |  85 ++++++++++++-
 src/membrain_seg/segmentation/finetune.py    |  58 +--------
 src/membrain_seg/segmentation/train.py       |   8 +-
 .../segmentation/training/optim_utils.py     | 113 ++++++++++++++++++
 4 files changed, 205 insertions(+), 59 deletions(-)

diff --git a/src/membrain_seg/segmentation/cli/fine_tune_cli.py b/src/membrain_seg/segmentation/cli/fine_tune_cli.py
index 2dceff6..a9b601a 100644
--- a/src/membrain_seg/segmentation/cli/fine_tune_cli.py
+++ b/src/membrain_seg/segmentation/cli/fine_tune_cli.py
@@ -23,6 +23,83 @@ def finetune(
         data structure, type "membrain data_structure_help"',
         **PKWARGS,
     ),
+):
+    """
+    Initiates fine-tuning of a pre-trained model on new datasets
+    and validation on original datasets.
+
+    This function fine-tunes a pre-trained model on new datasets provided by the user.
+    The directory specified by `finetune_data_dir` should be structured according to
+    the requirements of the training function.
+    For more details, use "membrain data_structure_help".
+
+    Parameters
+    ----------
+    pretrained_checkpoint_path : str
+        Path to the checkpoint file of the pre-trained model.
+    finetune_data_dir : str
+        Directory containing the new dataset for fine-tuning,
+        structured as per MemBrain's requirements.
+        Use "membrain data_structure_help" for detailed information
+        on the required data structure.
+
+    Note
+    ----
+    This command configures and executes a fine-tuning session
+    using the provided model checkpoint.
+    The actual fine-tuning logic resides in the function '_fine_tune'.
+    """
+    finetune_learning_rate = 1e-5
+    log_dir = "logs_finetune/"
+    batch_size = 2
+    num_workers = 8
+    max_epochs = 100
+    early_stop_threshold = 0.05
+    aug_prob_to_one = True
+    use_deep_supervision = True
+    project_name = "membrain-seg_finetune"
+    sub_name = "1"
+
+    _fine_tune(
+        pretrained_checkpoint_path=pretrained_checkpoint_path,
+        finetune_data_dir=finetune_data_dir,
+        finetune_learning_rate=finetune_learning_rate,
+        log_dir=log_dir,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+        early_stop_threshold=early_stop_threshold,
+        aug_prob_to_one=aug_prob_to_one,
+        use_deep_supervision=use_deep_supervision,
+        project_name=project_name,
+        sub_name=sub_name,
+    )
+
+
+@cli.command(name="finetune_advanced", no_args_is_help=True)
+def finetune_advanced(
+    pretrained_checkpoint_path: str = Option(  # noqa: B008
+        ...,
+        help="Path to the checkpoint of the pre-trained model.",
+        **PKWARGS,
+    ),
+    finetune_data_dir: str = Option(  # noqa: B008
+        ...,
+        help='Path to the directory containing the new data for fine-tuning. \
+        Following the same required structure as the train function. \
+        To learn more about the required \
+        data structure, type "membrain data_structure_help"',
+        **PKWARGS,
+    ),
+    finetune_learning_rate: float = Option(  # noqa: B008
+        1e-5,
+        help="Learning rate for fine-tuning the model. This parameter controls the \
+        step size at each iteration while moving toward a minimum loss. \
+        A smaller learning rate can lead to a more precise convergence but may \
+        require more epochs. Adjust based on your dataset size and complexity.",
+    ),
     log_dir: str = Option(  # noqa: B008
         "logs_fine_tune/",
         help="Log directory path. Finetuning logs will be stored here.",
@@ -81,7 +158,7 @@
 ):
     """
     Initiates fine-tuning of a pre-trained model on new datasets
-    and validation on original datasets.
+    and validation on original datasets with more advanced options.
 
     This function finetunes a pre-trained U-Net model on new data provided by the user.
     The `finetune_data_dir` should contain the following directories:
@@ -100,6 +177,11 @@
         structured as per the MemBrain's requirement.
         Use "membrain data_structure_help" for detailed information
         on the required data structure.
+    finetune_learning_rate : float
+        Learning rate for fine-tuning the model. This parameter controls the step size
+        at each iteration while moving toward a minimum loss. A smaller learning rate
+        can lead to a more precise convergence but may require more epochs.
+        Adjust based on your dataset size and complexity.
     log_dir : str
         Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
     batch_size : int
@@ -138,6 +220,7 @@
     _fine_tune(
         pretrained_checkpoint_path=pretrained_checkpoint_path,
         finetune_data_dir=finetune_data_dir,
+        finetune_learning_rate=finetune_learning_rate,
         log_dir=log_dir,
         batch_size=batch_size,
         num_workers=num_workers,
diff --git a/src/membrain_seg/segmentation/finetune.py b/src/membrain_seg/segmentation/finetune.py
index c6e852f..b440ada 100644
--- a/src/membrain_seg/segmentation/finetune.py
+++ b/src/membrain_seg/segmentation/finetune.py
@@ -1,8 +1,4 @@
-from typing import Optional
-
 import pytorch_lightning as pl
-import torch
-from pytorch_lightning import Callback
 from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 
@@ -10,6 +6,10 @@
     MemBrainSegDataModule,
 )
 from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
+from membrain_seg.segmentation.training.optim_utils import (
+    PrintLearningRate,
+    ToleranceCallback,
+)
 from membrain_seg.segmentation.training.training_param_summary import (
     print_training_parameters,
 )
@@ -151,60 +151,14 @@ def fine_tune(
         verbose=True,  # Print a message when a checkpoint is saved
     )
 
-    class ToleranceCallback(Callback):
-        """
-        Callback to stop training if the monitored metric deviates
-        beyond a certain threshold from the baseline value obtained
-        after the first epoch.
-        """
-
-        def __init__(self, metric_name: str, threshold: float):
-            super().__init__()
-            self.metric_name = metric_name
-            self.threshold = threshold
-            self.baseline_value: Optional[float] = (
-                None  # Baseline value will be set after the first epoch
-            )
-
-        def on_validation_epoch_end(
-            self, trainer: pl.Trainer, pl_module: pl.LightningModule
-        ):
-            # Access the metric value from the validation metrics
-            metric_value = trainer.callback_metrics.get(self.metric_name)
-
-            # If the metric value is a tensor, convert it to a float
-            if isinstance(metric_value, torch.Tensor):
-                metric_value = metric_value.item()
-
-            # Set the baseline value after the first validation epoch
-            if metric_value is not None:
-                if self.baseline_value is None:
-                    self.baseline_value = metric_value
-                    print(f"Baseline {self.metric_name} set to {self.baseline_value}")
-                    return []
-
-                # Check if the metric value deviates beyond the threshold
-                if abs(metric_value - self.baseline_value) > self.threshold:
-                    print(
-                        f"Stopping training as {self.metric_name} "
-                        f"deviates too far from the baseline value."
-                    )
-                    trainer.should_stop = True
-
+    # Set up ToleranceCallback by monitoring validation loss
     early_stop_metric = "val_loss"
-
     tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold)
 
     # Monitor learning rate changes
     lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
 
-    class PrintLearningRate(Callback):
-        """Callback to print the current learning rate at the start of each epoch."""
-
-        def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-            current_lr = trainer.optimizers[0].param_groups[0]["lr"]
-            print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")
-
+    # Print the current learning rate at the start of each epoch
     print_lr_cb = PrintLearningRate()
 
     # Initialize the trainer with specified precision, logger, and callbacks
diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py
index 9c576f5..2e7253f 100644
--- a/src/membrain_seg/segmentation/train.py
+++ b/src/membrain_seg/segmentation/train.py
@@ -1,7 +1,6 @@
 import warnings
 
 import pytorch_lightning as pl
-from pytorch_lightning import Callback
 from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 
@@ -9,6 +8,7 @@
     MemBrainSegDataModule,
 )
 from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
+from membrain_seg.segmentation.training.optim_utils import PrintLearningRate
 from membrain_seg.segmentation.training.training_param_summary import (
     print_training_parameters,
 )
@@ -124,12 +124,8 @@ def train(
 
     lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
 
-    class PrintLearningRate(Callback):
-        def on_epoch_start(self, trainer, pl_module):
-            current_lr = trainer.optimizers[0].param_groups[0]["lr"]
-            print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")
-
     print_lr_cb = PrintLearningRate()
+
+    # Set up the trainer
     trainer = pl.Trainer(
         precision="16-mixed",
diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index 9ca28c1..b56f8f0 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -1,7 +1,11 @@
+from typing import Optional
+
+import pytorch_lightning as pl
 import torch
 from monai.losses import DiceLoss, MaskedLoss
 from monai.networks.nets import DynUNet
 from monai.utils import LossReduction
+from pytorch_lightning import Callback
 from torch.nn.functional import (
     binary_cross_entropy_with_logits,
     sigmoid,
@@ -268,3 +272,112 @@ def forward(
         # Normalize loss
         loss = loss / sum(self.weights)
         return loss
+
+
+class PrintLearningRate(Callback):
+    """
+    Callback to print the current learning rate at the start of each epoch.
+
+    Methods
+    -------
+    on_epoch_start(trainer, pl_module)
+        Prints the current learning rate at the start of each epoch.
+    """
+
+    def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        """
+        Prints the current learning rate at the start of each epoch.
+
+        Parameters
+        ----------
+        trainer : pl.Trainer
+            The trainer object that manages the training process.
+        pl_module : pl.LightningModule
+            The model being trained.
+ """ + current_lr = trainer.optimizers[0].param_groups[0]["lr"] + print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}") + + +class ToleranceCallback(Callback): + """ + Callback to stop training if the monitored metric deviates + beyond a certain threshold from the baseline value obtained + after the first epoch. + + Parameters + ---------- + metric_name : str + The name of the metric to monitor. + threshold : float + The threshold value for deviation from the baseline. + + Methods + ------- + on_validation_epoch_end(trainer, pl_module) + Checks if the monitored metric deviates beyond the threshold + and stops training if it does. + """ + + def __init__(self, metric_name: str, threshold: float): + """ + Initializes the ToleranceCallback with the metric + to monitor and the deviation threshold. + + Parameters + ---------- + metric_name : str + The name of the metric to monitor. + threshold : float + The threshold value for deviation from the baseline. + """ + super().__init__() + self.metric_name = metric_name + self.threshold = threshold + self.baseline_value: Optional[float] = ( + None # Baseline value will be set after the first epoch + ) + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): + """ + Checks if the monitored metric deviates beyond the threshold + and stops training if it does. + + Parameters + ---------- + trainer : pl.Trainer + The trainer object that manages the training process. + pl_module : pl.LightningModule + The model being trained. + """ + # Access the metric value from the validation metrics + metric_value = trainer.callback_metrics.get(self.metric_name) + + # If the metric value is a tensor, convert it to a float + if isinstance(metric_value, torch.Tensor): + metric_value = metric_value.item() + + # Set the baseline value after the first validation epoch + if metric_value is not None: + if self.baseline_value is None: + self.baseline_value = metric_value + print(f"Baseline {self.metric_name} set to {self.baseline_value}") + return [] + + # Check if the metric value deviates beyond the threshold + if abs(metric_value - self.baseline_value) > self.threshold: + print( + f"Stopping training as {self.metric_name} " + f"deviates too far from the baseline value." + ) + trainer.should_stop = True