From 9d5fd599a0608238b0407434a43243fe3744f785 Mon Sep 17 00:00:00 2001
From: Hanyi Zhang
Date: Sat, 25 May 2024 19:05:23 +0200
Subject: [PATCH] Add finetune function

---
 src/membrain_seg/segmentation/cli/__init__.py |   1 +
 .../segmentation/cli/fine_tune_cli.py         | 153 ++++++++++++
 src/membrain_seg/segmentation/finetune.py     | 225 ++++++++++++++++++
 .../segmentation/training/optim_utils.py      |   2 +
 4 files changed, 381 insertions(+)
 create mode 100644 src/membrain_seg/segmentation/cli/fine_tune_cli.py
 create mode 100644 src/membrain_seg/segmentation/finetune.py

diff --git a/src/membrain_seg/segmentation/cli/__init__.py b/src/membrain_seg/segmentation/cli/__init__.py
index a2c078c..aad1241 100644
--- a/src/membrain_seg/segmentation/cli/__init__.py
+++ b/src/membrain_seg/segmentation/cli/__init__.py
@@ -2,6 +2,7 @@
 # These imports are necessary to register CLI commands. Do not remove!
 from .cli import cli  # noqa: F401
+from .fine_tune_cli import finetune  # noqa: F401
 from .segment_cli import segment  # noqa: F401
 from .ske_cli import skeletonize  # noqa: F401
 from .train_cli import data_dir_help, train  # noqa: F401
diff --git a/src/membrain_seg/segmentation/cli/fine_tune_cli.py b/src/membrain_seg/segmentation/cli/fine_tune_cli.py
new file mode 100644
index 0000000..2dceff6
--- /dev/null
+++ b/src/membrain_seg/segmentation/cli/fine_tune_cli.py
@@ -0,0 +1,153 @@
+from typing import List, Optional
+
+from typer import Option
+from typing_extensions import Annotated
+
+from ..finetune import fine_tune as _fine_tune
+from .cli import OPTION_PROMPT_KWARGS as PKWARGS
+from .cli import cli
+
+
+@cli.command(name="finetune", no_args_is_help=True)
+def finetune(
+    pretrained_checkpoint_path: str = Option(  # noqa: B008
+        ...,
+        help="Path to the checkpoint of the pre-trained model.",
+        **PKWARGS,
+    ),
+    finetune_data_dir: str = Option(  # noqa: B008
+        ...,
+        help='Path to the directory containing the new data for fine-tuning. \
+            Following the same required structure as the train function. \
+            To learn more about the required\
+            data structure, type "membrain data_structure_help"',
+        **PKWARGS,
+    ),
+    log_dir: str = Option(  # noqa: B008
+        "logs_fine_tune/",
+        help="Log directory path. Finetuning logs will be stored here.",
+    ),
+    batch_size: int = Option(  # noqa: B008
+        2,
+        help="Batch size for training.",
+    ),
+    num_workers: int = Option(  # noqa: B008
+        8,
+        help="Number of worker threads for data loading.",
+    ),
+    max_epochs: int = Option(  # noqa: B008
+        100,
+        help="Maximum number of epochs for fine-tuning.",
+    ),
+    early_stop_threshold: float = Option(  # noqa: B008
+        0.05,
+        help="Threshold for early stopping based on validation loss deviation.",
+    ),
+    aug_prob_to_one: bool = Option(  # noqa: B008
+        True,
+        help='Whether to augment with a probability of one. This helps with the \
+            model\'s generalization,\
+            but also severely increases training time.\
+            Pass "True" or "False".',
+    ),
+    use_surface_dice: bool = Option(  # noqa: B008
+        False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
+    ),
+    surface_dice_weight: float = Option(  # noqa: B008
+        1.0, help="Scaling factor for the Surface-Dice loss."
+    ),
+    surface_dice_tokens: Annotated[
+        Optional[List[str]],
+        Option(
+            help='List of tokens to \
+                use for the Surface-Dice loss. \
+                Pass tokens separately:\
+                For example, train_advanced --surface_dice_tokens "ds1" \
+                --surface_dice_tokens "ds2"'
+        ),
+    ] = None,
+    use_deep_supervision: bool = Option(  # noqa: B008
+        True, help='Whether to use deep supervision. Pass "True" or "False".'
+    ),
+    project_name: str = Option(  # noqa: B008
+        "membrain-seg_v0_finetune",
+        help="Project name. This helps to find your model again.",
+    ),
+    sub_name: str = Option(  # noqa: B008
+        "1",
+        help="Subproject name. For multiple runs in the same project,\
+            please specify sub_names.",
+    ),
+):
+    """
+    Initiates fine-tuning of a pre-trained model on new datasets
+    and validation on original datasets.
+
+    This function finetunes a pre-trained U-Net model on new data provided by the user.
+    The `finetune_data_dir` should contain the following directories:
+    - `imagesTr` and `labelsTr` for the user's own new training data.
+    - `imagesVal` and `labelsVal` for the old data, which will be used
+      for validation to ensure that the fine-tuned model's performance
+      is not significantly worse on the original training data than the
+      pre-trained model.
+
+    Parameters
+    ----------
+    pretrained_checkpoint_path : str
+        Path to the checkpoint file of the pre-trained model.
+    finetune_data_dir : str
+        Directory containing the new dataset for fine-tuning,
+        structured as per MemBrain's requirements.
+        Use "membrain data_structure_help" for detailed information
+        on the required data structure.
+    log_dir : str
+        Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
+    batch_size : int
+        Number of samples per batch, by default 2.
+    num_workers : int
+        Number of worker threads for data loading, by default 8.
+    max_epochs : int
+        Maximum number of fine-tuning epochs, by default 100.
+    early_stop_threshold : float
+        Threshold for early stopping based on validation loss deviation,
+        by default 0.05.
+    aug_prob_to_one : bool
+        Determines whether to apply very strong data augmentation, by default True.
+        If set to False, data augmentation still happens, but not as frequently.
+        More data augmentation can lead to better performance, but also increases the
+        training time substantially.
+    use_surface_dice : bool
+        Determines whether to use Surface-Dice loss, by default False.
+    surface_dice_weight : float
+        Scaling factor for the Surface-Dice loss, by default 1.0.
+    surface_dice_tokens : list
+        List of tokens to use for the Surface-Dice loss.
+    use_deep_supervision : bool
+        Determines whether to use deep supervision, by default True.
+    project_name : str
+        Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'.
+    sub_name : str
+        Sub-name for the project, by default '1'.
+
+    Note
+    ----
+    This command configures and executes a fine-tuning session
+    using the provided model checkpoint.
+    The actual fine-tuning logic resides in the function '_fine_tune'.
+ """ + _fine_tune( + pretrained_checkpoint_path=pretrained_checkpoint_path, + finetune_data_dir=finetune_data_dir, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + early_stop_threshold=early_stop_threshold, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surface_dice, + surf_dice_weight=surface_dice_weight, + surf_dice_tokens=surface_dice_tokens, + project_name=project_name, + sub_name=sub_name, + ) diff --git a/src/membrain_seg/segmentation/finetune.py b/src/membrain_seg/segmentation/finetune.py new file mode 100644 index 0000000..c6e852f --- /dev/null +++ b/src/membrain_seg/segmentation/finetune.py @@ -0,0 +1,225 @@ +from typing import Optional + +import pytorch_lightning as pl +import torch +from pytorch_lightning import Callback +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint + +from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import ( + MemBrainSegDataModule, +) +from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet +from membrain_seg.segmentation.training.training_param_summary import ( + print_training_parameters, +) + + +def fine_tune( + pretrained_checkpoint_path: str, + finetune_data_dir: str, + finetune_learning_rate: float = 1e-5, + log_dir: str = "logs_finetune/", + batch_size: int = 2, + num_workers: int = 8, + max_epochs: int = 100, + early_stop_threshold: float = 0.05, + aug_prob_to_one: bool = False, + use_deep_supervision: bool = False, + project_name: str = "membrain-seg_finetune", + sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, +) -> None: + """ + Fine-tune a pre-trained U-Net model on new datasets. + + This function finetunes a pre-trained U-Net model on new data provided by the user. + The `finetune_data_dir` should contain the following directories: + - `imagesTr` and `labelsTr` for the user's own new training data. + - `imagesVal` and `labelsVal` for the old data, which will be used + for validation to ensure that the fine-tuned model's performance + is not significantly worse on the original training data than the + pre-trained model. + + Callbacks used during the fine-tuning process + --------- + - ModelCheckpoint: Saves the model checkpoints based on training loss + and at regular intervals. + - ToleranceCallback: Stops training if the validation loss deviates significantly + from the baseline value set after the first epoch. + - LearningRateMonitor: Monitors and logs the learning rate during training. + - PrintLearningRate: Prints the current learning rate at the start of each epoch. + + Parameters + ---------- + pretrained_checkpoint_path : str + Path to the checkpoint of the pre-trained model. + finetune_data_dir : str + Path to the directory containing the new data for fine-tuning + and old data for validation. + finetune_learning_rate : float, optional + Learning rate for fine-tuning, by default 1e-5. + log_dir : str, optional + Path to the directory where logs should be stored. + batch_size : int, optional + Number of samples per batch of input data. + num_workers : int, optional + Number of subprocesses to use for data loading. + max_epochs : int, optional + Maximum number of epochs to finetune, by default 100. + early_stop_threshold : float, optional + Threshold for early stopping based on validation loss deviation, + by default 0.05. 
+    aug_prob_to_one : bool, optional
+        If True, all augmentation probabilities are set to 1.
+    use_deep_supervision : bool, optional
+        If True, enables deep supervision in the U-Net model.
+    project_name : str, optional
+        Name of the project for logging purposes.
+    sub_name : str, optional
+        Sub-name of the project for logging purposes.
+    use_surf_dice : bool, optional
+        If True, enables Surface-Dice loss.
+    surf_dice_weight : float, optional
+        Weight for the Surface-Dice loss.
+    surf_dice_tokens : list, optional
+        List of tokens to use for the Surface-Dice loss.
+
+    Returns
+    -------
+    None
+    """
+    # Print training parameters for verification
+    print_training_parameters(
+        data_dir=finetune_data_dir,
+        log_dir=log_dir,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+        aug_prob_to_one=aug_prob_to_one,
+        use_deep_supervision=use_deep_supervision,
+        project_name=project_name,
+        sub_name=sub_name,
+        use_surf_dice=use_surf_dice,
+        surf_dice_weight=surf_dice_weight,
+        surf_dice_tokens=surf_dice_tokens,
+    )
+    print("————————————————————————————————————————————————————————")
+    print(
+        f"Pretrained Checkpoint:\n"
+        f"   '{pretrained_checkpoint_path}' \n"
+        f"   Path to the pretrained model checkpoint."
+    )
+    print("\n")
+
+    # Initialize the data module with fine-tuning datasets
+    # New data for finetuning and old data for validation
+    finetune_data_module = MemBrainSegDataModule(
+        data_dir=finetune_data_dir,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        aug_prob_to_one=aug_prob_to_one,
+    )
+
+    # Load the pre-trained model with updated learning rate
+    pretrained_model = SemanticSegmentationUnet.load_from_checkpoint(
+        pretrained_checkpoint_path, learning_rate=finetune_learning_rate
+    )
+
+    checkpointing_name = project_name + "_" + sub_name
+
+    # Set up logging
+    csv_logger = pl_loggers.CSVLogger(log_dir)
+
+    # Set up model checkpointing based on training loss
+    checkpoint_callback_train_loss = ModelCheckpoint(
+        dirpath="finetuned_checkpoints/",
+        filename=checkpointing_name + "-{epoch:02d}-{train_loss:.2f}",
+        monitor="train_loss",
+        mode="min",
+        save_top_k=3,
+    )
+
+    # Set up regular checkpointing every 5 epochs
+    checkpoint_callback_regular = ModelCheckpoint(
+        save_top_k=-1,  # Save all checkpoints
+        every_n_epochs=5,
+        dirpath="finetuned_checkpoints/",
+        filename=checkpointing_name + "-{epoch}-{train_loss:.2f}",
+        verbose=True,  # Print a message when a checkpoint is saved
+    )
+
+    class ToleranceCallback(Callback):
+        """
+        Callback to stop training if the monitored metric deviates
+        beyond a certain threshold from the baseline value obtained
+        after the first epoch.
+ """ + + def __init__(self, metric_name: str, threshold: float): + super().__init__() + self.metric_name = metric_name + self.threshold = threshold + self.baseline_value: Optional[float] = ( + None # Baseline value will be set after the first epoch + ) + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): + # Access the metric value from the validation metrics + metric_value = trainer.callback_metrics.get(self.metric_name) + + # If the metric value is a tensor, convert it to a float + if isinstance(metric_value, torch.Tensor): + metric_value = metric_value.item() + + # Set the baseline value after the first validation epoch + if metric_value is not None: + if self.baseline_value is None: + self.baseline_value = metric_value + print(f"Baseline {self.metric_name} set to {self.baseline_value}") + return [] + + # Check if the metric value deviates beyond the threshold + if abs(metric_value - self.baseline_value) > self.threshold: + print( + f"Stopping training as {self.metric_name} " + f"deviates too far from the baseline value." + ) + trainer.should_stop = True + + early_stop_metric = "val_loss" + + tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold) + + # Monitor learning rate changes + lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True) + + class PrintLearningRate(Callback): + """Callback to print the current learning rate at the start of each epoch.""" + + def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + current_lr = trainer.optimizers[0].param_groups[0]["lr"] + print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}") + + print_lr_cb = PrintLearningRate() + + # Initialize the trainer with specified precision, logger, and callbacks + trainer = pl.Trainer( + precision="16-mixed", + logger=[csv_logger], + callbacks=[ + checkpoint_callback_train_loss, + checkpoint_callback_regular, + lr_monitor, + print_lr_cb, + tolerance_callback, + ], + max_epochs=max_epochs, + ) + + # Start the fine-tuning process + trainer.fit(pretrained_model, finetune_data_module) diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py index a16e136..9ca28c1 100644 --- a/src/membrain_seg/segmentation/training/optim_utils.py +++ b/src/membrain_seg/segmentation/training/optim_utils.py @@ -115,6 +115,8 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: combined_loss = combined_loss.mean() elif self.reduction == "sum": combined_loss = combined_loss.sum() + elif self.reduction == "none": + return combined_loss else: raise ValueError( f"Invalid reduction type {self.reduction}. "