Commit 9d5fd59 by Hanyi Zhang, committed May 25, 2024 (1 parent: d005be5). Showing 4 changed files with 381 additions and 0 deletions.
@@ -0,0 +1,153 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli


@cli.command(name="finetune", no_args_is_help=True)
def finetune(
    pretrained_checkpoint_path: str = Option(  # noqa: B008
        ...,
        help="Path to the checkpoint of the pre-trained model.",
        **PKWARGS,
    ),
    finetune_data_dir: str = Option(  # noqa: B008
        ...,
        help='Path to the directory containing the new data for fine-tuning. \
            Following the same required structure as the train function. \
            To learn more about the required \
            data structure, type "membrain data_structure_help"',
        **PKWARGS,
    ),
    log_dir: str = Option(  # noqa: B008
        "logs_fine_tune/",
        help="Log directory path. Finetuning logs will be stored here.",
    ),
    batch_size: int = Option(  # noqa: B008
        2,
        help="Batch size for training.",
    ),
    num_workers: int = Option(  # noqa: B008
        8,
        help="Number of worker threads for data loading.",
    ),
    max_epochs: int = Option(  # noqa: B008
        100,
        help="Maximum number of epochs for fine-tuning.",
    ),
    early_stop_threshold: float = Option(  # noqa: B008
        0.05,
        help="Threshold for early stopping based on validation loss deviation.",
    ),
    aug_prob_to_one: bool = Option(  # noqa: B008
        True,
        help='Whether to augment with a probability of one. This helps with the \
            model\'s generalization, \
            but also severely increases training time. \
            Pass "True" or "False".',
    ),
    use_surface_dice: bool = Option(  # noqa: B008
        False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
    ),
    surface_dice_weight: float = Option(  # noqa: B008
        1.0, help="Scaling factor for the Surface-Dice loss."
    ),
    surface_dice_tokens: Annotated[
        Optional[List[str]],
        Option(
            help='List of tokens to \
            use for the Surface-Dice loss. \
            Pass tokens separately: \
            For example, train_advanced --surface_dice_tokens "ds1" \
            --surface_dice_tokens "ds2"'
        ),
    ] = None,
    use_deep_supervision: bool = Option(  # noqa: B008
        True, help='Whether to use deep supervision. Pass "True" or "False".'
    ),
    project_name: str = Option(  # noqa: B008
        "membrain-seg_v0_finetune",
        help="Project name. This helps to find your model again.",
    ),
    sub_name: str = Option(  # noqa: B008
        "1",
        help="Subproject name. For multiple runs in the same project, \
            please specify sub_names.",
    ),
):
    """
    Initiate fine-tuning of a pre-trained model on new datasets
    and validation on original datasets.

    This function finetunes a pre-trained U-Net model on new data provided by the user.
    The `finetune_data_dir` should contain the following directories:
    - `imagesTr` and `labelsTr` for the user's own new training data.
    - `imagesVal` and `labelsVal` for the old data, which will be used
      for validation to ensure that the fine-tuned model's performance
      is not significantly worse on the original training data than the
      pre-trained model.

    Parameters
    ----------
    pretrained_checkpoint_path : str
        Path to the checkpoint file of the pre-trained model.
    finetune_data_dir : str
        Directory containing the new dataset for fine-tuning,
        structured as per MemBrain's requirements.
        Use "membrain data_structure_help" for detailed information
        on the required data structure.
    log_dir : str
        Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
    batch_size : int
        Number of samples per batch, by default 2.
    num_workers : int
        Number of worker threads for data loading, by default 8.
    max_epochs : int
        Maximum number of fine-tuning epochs, by default 100.
    early_stop_threshold : float
        Threshold for early stopping based on validation loss deviation,
        by default 0.05.
    aug_prob_to_one : bool
        Determines whether to apply very strong data augmentation, by default True.
        If set to False, data augmentation still happens, but not as frequently.
        More data augmentation can lead to better performance, but also increases the
        training time substantially.
    use_surface_dice : bool
        Determines whether to use Surface-Dice loss, by default False.
    surface_dice_weight : float
        Scaling factor for the Surface-Dice loss, by default 1.0.
    surface_dice_tokens : list
        List of tokens to use for the Surface-Dice loss.
    use_deep_supervision : bool
        Determines whether to use deep supervision, by default True.
    project_name : str
        Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'.
    sub_name : str
        Sub-name for the project, by default '1'.

    Note
    ----
    This command configures and executes a fine-tuning session
    using the provided model checkpoint.
    The actual fine-tuning logic resides in the function '_fine_tune'.
    """
    _fine_tune(
        pretrained_checkpoint_path=pretrained_checkpoint_path,
        finetune_data_dir=finetune_data_dir,
        log_dir=log_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
        early_stop_threshold=early_stop_threshold,
        aug_prob_to_one=aug_prob_to_one,
        use_deep_supervision=use_deep_supervision,
        use_surf_dice=use_surface_dice,
        surf_dice_weight=surface_dice_weight,
        surf_dice_tokens=surface_dice_tokens,
        project_name=project_name,
        sub_name=sub_name,
    )
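The help texts and docstring above describe the directory layout that `finetune_data_dir` must follow: `imagesTr`/`labelsTr` for the new training data and `imagesVal`/`labelsVal` for the original validation data. As a minimal sketch (not part of this commit; the helper name and the example path are hypothetical), such a layout can be checked up front before launching a run:

from pathlib import Path

# Sub-directories expected for fine-tuning, as described in the docstring above
# (new training data plus old data kept for validation).
EXPECTED_SUBDIRS = ("imagesTr", "labelsTr", "imagesVal", "labelsVal")


def check_finetune_dir(finetune_data_dir: str) -> None:
    """Raise if the fine-tuning directory misses an expected folder (hypothetical helper)."""
    root = Path(finetune_data_dir)
    missing = [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]
    if missing:
        raise FileNotFoundError(
            f"{finetune_data_dir} is missing sub-directories: {', '.join(missing)}. "
            'Type "membrain data_structure_help" for the required layout.'
        )


check_finetune_dir("./finetune_data")  # hypothetical example path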
@@ -0,0 +1,225 @@
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import (
    MemBrainSegDataModule,
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.training.training_param_summary import (
    print_training_parameters,
)


def fine_tune(
    pretrained_checkpoint_path: str,
    finetune_data_dir: str,
    finetune_learning_rate: float = 1e-5,
    log_dir: str = "logs_finetune/",
    batch_size: int = 2,
    num_workers: int = 8,
    max_epochs: int = 100,
    early_stop_threshold: float = 0.05,
    aug_prob_to_one: bool = False,
    use_deep_supervision: bool = False,
    project_name: str = "membrain-seg_finetune",
    sub_name: str = "1",
    use_surf_dice: bool = False,
    surf_dice_weight: float = 1.0,
    surf_dice_tokens: list = None,
) -> None:
    """
    Fine-tune a pre-trained U-Net model on new datasets.

    This function finetunes a pre-trained U-Net model on new data provided by the user.
    The `finetune_data_dir` should contain the following directories:
    - `imagesTr` and `labelsTr` for the user's own new training data.
    - `imagesVal` and `labelsVal` for the old data, which will be used
      for validation to ensure that the fine-tuned model's performance
      is not significantly worse on the original training data than the
      pre-trained model.

    Callbacks used during the fine-tuning process
    ---------
    - ModelCheckpoint: Saves the model checkpoints based on training loss
      and at regular intervals.
    - ToleranceCallback: Stops training if the validation loss deviates significantly
      from the baseline value set after the first epoch.
    - LearningRateMonitor: Monitors and logs the learning rate during training.
    - PrintLearningRate: Prints the current learning rate at the start of each epoch.

    Parameters
    ----------
    pretrained_checkpoint_path : str
        Path to the checkpoint of the pre-trained model.
    finetune_data_dir : str
        Path to the directory containing the new data for fine-tuning
        and old data for validation.
    finetune_learning_rate : float, optional
        Learning rate for fine-tuning, by default 1e-5.
    log_dir : str, optional
        Path to the directory where logs should be stored.
    batch_size : int, optional
        Number of samples per batch of input data.
    num_workers : int, optional
        Number of subprocesses to use for data loading.
    max_epochs : int, optional
        Maximum number of epochs to finetune, by default 100.
    early_stop_threshold : float, optional
        Threshold for early stopping based on validation loss deviation,
        by default 0.05.
    aug_prob_to_one : bool, optional
        If True, all augmentation probabilities are set to 1.
    use_deep_supervision : bool, optional
        If True, enables deep supervision in the U-Net model.
    project_name : str, optional
        Name of the project for logging purposes.
    sub_name : str, optional
        Sub-name of the project for logging purposes.
    use_surf_dice : bool, optional
        If True, enables Surface-Dice loss.
    surf_dice_weight : float, optional
        Weight for the Surface-Dice loss.
    surf_dice_tokens : list, optional
        List of tokens to use for the Surface-Dice loss.

    Returns
    -------
    None
    """
    # Print training parameters for verification
    print_training_parameters(
        data_dir=finetune_data_dir,
        log_dir=log_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
        aug_prob_to_one=aug_prob_to_one,
        use_deep_supervision=use_deep_supervision,
        project_name=project_name,
        sub_name=sub_name,
        use_surf_dice=use_surf_dice,
        surf_dice_weight=surf_dice_weight,
        surf_dice_tokens=surf_dice_tokens,
    )
    print("————————————————————————————————————————————————————————")
    print(
        f"Pretrained Checkpoint:\n"
        f"   '{pretrained_checkpoint_path}' \n"
        f"   Path to the pretrained model checkpoint."
    )
    print("\n")

    # Initialize the data module with fine-tuning datasets
    # New data for finetuning and old data for validation
    finetune_data_module = MemBrainSegDataModule(
        data_dir=finetune_data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        aug_prob_to_one=aug_prob_to_one,
    )

    # Load the pre-trained model with updated learning rate
    pretrained_model = SemanticSegmentationUnet.load_from_checkpoint(
        pretrained_checkpoint_path, learning_rate=finetune_learning_rate
    )

    checkpointing_name = project_name + "_" + sub_name

    # Set up logging
    csv_logger = pl_loggers.CSVLogger(log_dir)

    # Set up model checkpointing based on training loss
    checkpoint_callback_train_loss = ModelCheckpoint(
        dirpath="finetuned_checkpoints/",
        filename=checkpointing_name + "-{epoch:02d}-{train_loss:.2f}",
        monitor="train_loss",
        mode="min",
        save_top_k=3,
    )

    # Set up regular checkpointing every 5 epochs
    checkpoint_callback_regular = ModelCheckpoint(
        save_top_k=-1,  # Save all checkpoints
        every_n_epochs=5,
        dirpath="finetuned_checkpoints/",
        filename=checkpointing_name + "-{epoch}-{train_loss:.2f}",
        verbose=True,  # Print a message when a checkpoint is saved
    )

    class ToleranceCallback(Callback):
        """
        Callback to stop training if the monitored metric deviates
        beyond a certain threshold from the baseline value obtained
        after the first epoch.
        """

        def __init__(self, metric_name: str, threshold: float):
            super().__init__()
            self.metric_name = metric_name
            self.threshold = threshold
            self.baseline_value: Optional[float] = (
                None  # Baseline value will be set after the first epoch
            )

        def on_validation_epoch_end(
            self, trainer: pl.Trainer, pl_module: pl.LightningModule
        ):
            # Access the metric value from the validation metrics
            metric_value = trainer.callback_metrics.get(self.metric_name)

            # If the metric value is a tensor, convert it to a float
            if isinstance(metric_value, torch.Tensor):
                metric_value = metric_value.item()

            # Set the baseline value after the first validation epoch
            if metric_value is not None:
                if self.baseline_value is None:
                    self.baseline_value = metric_value
                    print(f"Baseline {self.metric_name} set to {self.baseline_value}")
                    return []

                # Check if the metric value deviates beyond the threshold
                if abs(metric_value - self.baseline_value) > self.threshold:
                    print(
                        f"Stopping training as {self.metric_name} "
                        f"deviates too far from the baseline value."
                    )
                    trainer.should_stop = True

    early_stop_metric = "val_loss"

    tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold)

    # Monitor learning rate changes
    lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

    class PrintLearningRate(Callback):
        """Callback to print the current learning rate at the start of each epoch."""

        def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
            current_lr = trainer.optimizers[0].param_groups[0]["lr"]
            print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")

    print_lr_cb = PrintLearningRate()

    # Initialize the trainer with specified precision, logger, and callbacks
    trainer = pl.Trainer(
        precision="16-mixed",
        logger=[csv_logger],
        callbacks=[
            checkpoint_callback_train_loss,
            checkpoint_callback_regular,
            lr_monitor,
            print_lr_cb,
            tolerance_callback,
        ],
        max_epochs=max_epochs,
    )

    # Start the fine-tuning process
    trainer.fit(pretrained_model, finetune_data_module)
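For completeness, the same entry point can also be called directly from Python instead of through the `finetune` CLI command. A minimal usage sketch, assuming the module path inferred from the CLI's relative import (`..finetune`) and using placeholder paths that are not part of this commit:

# Hedged usage sketch: direct call into fine_tune with hypothetical paths.
from membrain_seg.segmentation.finetune import fine_tune  # assumed module path

fine_tune(
    pretrained_checkpoint_path="checkpoints/pretrained_model.ckpt",  # placeholder
    finetune_data_dir="finetune_data/",  # must contain imagesTr/labelsTr and imagesVal/labelsVal
    finetune_learning_rate=1e-5,
    log_dir="logs_finetune/",
    batch_size=2,
    num_workers=8,
    max_epochs=100,
    early_stop_threshold=0.05,
    aug_prob_to_one=True,
    use_deep_supervision=True,
    use_surf_dice=False,
    project_name="membrain-seg_finetune",
    sub_name="1",
)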