Step based checkpointing and tests
joecummings committed Feb 11, 2025
1 parent c5093b7 commit b0648fd
Showing 5 changed files with 205 additions and 20 deletions.
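In short, this commit adds step-based checkpointing to the single-device full-finetune recipe: a checkpoint is written every `save_every_n_steps` optimizer steps (defaulting to once per epoch when the option is unset), checkpoint directories are named `step_{N}` rather than `epoch_{N}`, and the HF checkpointer can prune older checkpoints down to `keep_last_n_checkpoints`. A condensed, runnable sketch of the trigger logic follows; `StepCheckpointing` is an illustrative name, not part of the diff, and it simply mirrors the recipe changes shown below.

```python
# Illustrative sketch of the step-based save trigger added in this commit.
# "StepCheckpointing" is a made-up name; the attribute names mirror the recipe diff.


class StepCheckpointing:
    def __init__(self, save_every_n_steps, steps_per_epoch, total_epochs):
        # Default preserves the old behavior: one checkpoint per epoch boundary.
        self.save_every_n_steps = (
            save_every_n_steps if save_every_n_steps is not None else steps_per_epoch
        )
        self.steps_per_epoch = steps_per_epoch
        self.total_epochs = total_epochs

    def should_save(self, global_step: int) -> bool:
        # Mirrors the check added to train(): save on every multiple of save_every_n_steps.
        return global_step > 0 and global_step % self.save_every_n_steps == 0

    def is_intermediate(self, global_step: int) -> bool:
        # Intermediate checkpoints also carry seed/epoch/step/optimizer state for resume.
        return global_step < self.steps_per_epoch * self.total_epochs


ckpt = StepCheckpointing(save_every_n_steps=2, steps_per_epoch=2, total_epochs=2)
assert [s for s in range(1, 5) if ckpt.should_save(s)] == [2, 4]
assert ckpt.is_intermediate(2) and not ckpt.is_intermediate(4)
```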
30 changes: 24 additions & 6 deletions recipes/full_finetune_single_device.py
@@ -23,6 +23,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.checkpointing.constants import CURR_STEP_KEY
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm
@@ -139,6 +140,7 @@ def __init__(self, cfg: DictConfig) -> None:

# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self.save_every_n_steps = cfg.get("save_every_n_steps")
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.optimizer_in_bwd
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
@@ -324,6 +326,10 @@ def setup(self, cfg: DictConfig) -> None:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch

# For now, default to saving at epoch boundaries
if self.save_every_n_steps is None:
self.save_every_n_steps = self._steps_per_epoch

# Setup lr scheduler
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
@@ -596,30 +602,35 @@ def _setup_data(

return sampler, dataloader

def save_checkpoint(self, epoch: int) -> None:
def save_checkpoint(self, *, epoch: int, step: int) -> None:
"""
Save state dict to file. The recipe save_checkpoint method is responsible for
correctly creating the checkpoint dict and passing it to the checkpointer.
"""
ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
# if training is in-progress, checkpoint the optimizer state as well
if epoch + 1 < self.total_epochs:

# If training is in-progress, checkpoint the optimizer state as well
is_intermediate = step < self._steps_per_epoch * self.total_epochs
if is_intermediate:
ckpt_dict.update(
{
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.EPOCHS_KEY: epoch,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
CURR_STEP_KEY: step,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
if not self._optimizer_in_bwd:
ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict()
else:
ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()

self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
intermediate_checkpoint=(epoch + 1 < self.total_epochs),
intermediate_checkpoint=is_intermediate,
step=step,
)

def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -753,6 +764,13 @@ def train(self) -> None:
step=self.global_step,
)

# Save checkpoint if specified by user
if (
self.global_step > 0
and self.global_step % self.save_every_n_steps == 0
):
self.save_checkpoint(epoch=curr_epoch, step=self.global_step)

# Reset running stats for the next step
running_loss = 0
num_tokens = 0
@@ -776,8 +794,8 @@
self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)

self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
self._profiler.stop()

def cleanup(self) -> None:
127 changes: 126 additions & 1 deletion tests/recipes/test_full_finetune_single_device.py
@@ -5,14 +5,14 @@
# LICENSE file in the root directory of this source tree.

import os
import re

import runpy

import sys
from pathlib import Path

import pytest

import torch
from tests.common import TUNE_PATH

@@ -214,3 +214,128 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize("keep_last_n_checkpoints", [1, 2])
@pytest.mark.parametrize("save_every_n_steps", [1, 2])
def test_checkpointing_with_steps(
self, tmpdir, monkeypatch, keep_last_n_checkpoints, save_every_n_steps
):
ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)
write_hf_ckpt_config(tmpdir)

# Train for two epochs (anywhere from 2 -> 4 ckpts)
cmd_1 = f"""
tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
batch_size=8 \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
checkpointer.keep_last_n_checkpoints={keep_last_n_checkpoints} \
save_every_n_steps={save_every_n_steps} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
model_config = MODEL_TEST_CONFIGS["llama2"]
cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_1)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

regex_to_match = re.compile("step_([0-9]+)")
# Iterate over the directory contents, find all directories that match
# `regex_to_match`. Assert that the number of directories found is equal
# to the `keep_last_n_checkpoints` value. Also assert that each checkpoint
# number is a multiple of `save_every_n_steps`.
ckpt_dirs = [
d
for d in os.listdir(tmpdir)
if os.path.isdir(os.path.join(tmpdir, d)) and regex_to_match.match(d)
]
assert len(ckpt_dirs) == keep_last_n_checkpoints
for ckpt_dir in ckpt_dirs:
step = int(regex_to_match.match(ckpt_dir).group(1))
assert step % save_every_n_steps == 0

# Also make sure that the last checkpoint has the correct number of steps
most_recent_checkpoint = get_largest_iter_folder(tmpdir, pattern=r"^step_(\d+)")
step = int(regex_to_match.match(most_recent_checkpoint).group(1))
assert step == 4 # 2 epochs * 2 steps per epoch

@pytest.mark.integration_test
def test_checkpointing_with_steps_and_resume(self, tmpdir, monkeypatch):
"""We want to be sure that now we use steps, we can resume correctly from a checkpoint.
Once we fully transition to steps, we can remove the test above."""
# 0. Set up variables
ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# 1. Train for two epochs, keep 2 checkpoints
cmd_1 = f"""
tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
batch_size=8 \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
checkpointer.keep_last_n_checkpoints=2 \
save_every_n_steps=2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
model_config = MODEL_TEST_CONFIGS["llama2"]
cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_1)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# 2. Find the checkpoint at the end of the first epoch
step_folder = get_largest_iter_folder(tmpdir, pattern=r"^step_(\d+)")
step_folder_at_epoch_boundary = f"step_{int(step_folder.split('_')[-1]) - 2}"
suffix = ".safetensors"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)

# 3. Resume training w/ the checkpoint from epoch boundary
cmd_2 = f"""
tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
batch_size=8 \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{os.path.join(step_folder_at_epoch_boundary, model_ckpt_fname)}]\
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# 4. Make sure loss values match the expected values
expected_loss_values = self._fetch_expected_loss_values("llama2")[2:]
loss_values = get_loss_values_from_metric_logger(log_file)
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)
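For orientation, the two tests above run for 2 epochs with 2 gradient steps per epoch (per the `step == 4` comment), so the set of surviving `step_*` directories follows directly from `save_every_n_steps` and `keep_last_n_checkpoints`. A small sketch of that bookkeeping, with the epoch/step counts taken from the test comments; the helper below is illustrative and not part of the test suite.

```python
# Expected surviving checkpoint folders for the tests above.
# Assumes 2 epochs x 2 steps per epoch = 4 global steps, per the test comments.
TOTAL_STEPS = 2 * 2


def expected_step_dirs(save_every_n_steps: int, keep_last_n_checkpoints: int) -> list:
    # A checkpoint lands at every multiple of save_every_n_steps;
    # pruning keeps only the newest keep_last_n_checkpoints of them.
    saved = [s for s in range(1, TOTAL_STEPS + 1) if s % save_every_n_steps == 0]
    return [f"step_{s}" for s in saved[-keep_last_n_checkpoints:]]


assert expected_step_dirs(save_every_n_steps=2, keep_last_n_checkpoints=2) == ["step_2", "step_4"]
assert expected_step_dirs(save_every_n_steps=1, keep_last_n_checkpoints=1) == ["step_4"]
```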
43 changes: 35 additions & 8 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -32,9 +32,11 @@
check_outdir_not_in_ckptdir,
copy_files,
get_adapter_checkpoint_path,
get_all_checkpoints_in_dir,
get_model_checkpoint_path,
get_recipe_checkpoint_path,
ModelType,
prune_surplus_checkpoints,
RECIPE_STATE_DIRNAME,
REPO_ID_FNAME,
safe_torch_load,
@@ -399,6 +401,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
Default is True.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False.
keep_last_n_checkpoints (Optional[int]): How many checkpoints to keep. If None, all checkpoints are kept.
"""

def __init__(
@@ -412,6 +415,8 @@ def __init__(
resume_from_checkpoint: bool = False,
safe_serialization: bool = True,
should_load_recipe_state: bool = False,
*,
keep_last_n_checkpoints: Optional[int] = None,
) -> None:

self._should_load_recipe_state = should_load_recipe_state
@@ -420,6 +425,7 @@ def __init__(
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)
self._keep_last_n_checkpoints = keep_last_n_checkpoints

self._safe_serialization = safe_serialization
self._checkpoint_dir = Path(checkpoint_dir)
@@ -457,7 +463,7 @@ def __init__(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
pattern=r"^step_(\d+)",
)

# resume recipe_state ckpt
@@ -629,6 +635,8 @@ def save_checkpoint(
epoch: int,
intermediate_checkpoint: bool = False,
adapter_only: bool = False,
*,
step: Optional[int] = None,
) -> None:
"""
Save HF checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
@@ -644,10 +652,19 @@
intermediate_checkpoint (bool): If True, additional checkpoint files for recipe state
and (if applicable) adapter weights are created. Default is False
adapter_only (bool): If True, only save the adapter weights. Default is False
step (Optional[int]): Step number. If provided, it is used to name the checkpoint output directory.
Raises:
ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict.
"""
# Prefer to use step, not epoch
if step is not None:
ckpt_save_dirname = f"step_{step}"
ckpt_pattern = r"^step_(\d+)"
else:
ckpt_save_dirname = f"epoch_{epoch}"
ckpt_pattern = r"^epoch_(\d+)"

# convert the state_dict back to hf format; do this inplace
if not adapter_only:
if self._model_type == ModelType.PHI3_MINI:
@@ -747,7 +764,7 @@ def save_checkpoint(
)
map_original_name_to_new_name[cpt_idx] = shard_name
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", shard_name
self._output_dir, ckpt_save_dirname, shard_name
)
output_path.parent.mkdir(parents=True, exist_ok=True)
if not self._safe_serialization:
@@ -779,7 +796,7 @@ def save_checkpoint(
index_file_name = TORCH_INDEX_FNAME

index_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", index_file_name
self._output_dir, ckpt_save_dirname, index_file_name
)

index_data = {
@@ -796,7 +813,7 @@ def save_checkpoint(
# convert_weights.peft_to_tune. The .pt format is not needed, but
# it is an easy way to distinguish the adapters. Ideally we should save only one.
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
self._output_dir, ckpt_save_dirname, ADAPTER_MODEL_FNAME
).with_suffix(".pt")
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(state_dict[training.ADAPTER_KEY], output_path)
@@ -825,7 +842,7 @@ def save_checkpoint(
head_dim=self._config.get("head_dim", None),
)
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
self._output_dir, ckpt_save_dirname, ADAPTER_MODEL_FNAME
)
output_path.parent.mkdir(parents=True, exist_ok=True)
if not self._safe_serialization:
@@ -866,7 +883,7 @@ def save_checkpoint(
)

output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", ADAPTER_CONFIG_FNAME
self._output_dir, ckpt_save_dirname, ADAPTER_CONFIG_FNAME
).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
@@ -880,7 +897,7 @@ def save_checkpoint(
# So it's easy to run inference with the model using this epoch's checkpoint
copy_files(
self._checkpoint_dir,
Path.joinpath(self._output_dir, f"epoch_{epoch}"),
Path.joinpath(self._output_dir, ckpt_save_dirname),
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)

@@ -901,7 +918,7 @@ def save_checkpoint(
f"saved to {output_path}"
)
else:
logger.info("Saving final epoch checkpoint.")
logger.info("Saving final checkpoint.")
if adapter_only:
logger.info(
"Please note that you have set adapter_only=True, so only adapter weights will be saved."
@@ -914,6 +931,16 @@ def save_checkpoint(
"You can now use this checkpoint for further training or inference."
)

# If specified, prune the checkpoints in the output directory
if self._keep_last_n_checkpoints is not None:
all_current_checkpoints = get_all_checkpoints_in_dir(
self._output_dir, pattern=ckpt_pattern
)
prune_surplus_checkpoints(
all_current_checkpoints,
keep_last_n_checkpoints=self._keep_last_n_checkpoints,
)


class FullModelMetaCheckpointer(_CheckpointerInterface):
"""
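The pruning path above relies on two helpers that are not shown in this diff, `get_all_checkpoints_in_dir` and `prune_surplus_checkpoints` (presumably added in one of the checkpointing files not expanded in this view). A minimal sketch of the behavior the call sites imply — list the matching checkpoint directories, then delete all but the newest N — assuming both operate on `Path` objects; the real torchtune implementations may differ.

```python
# Sketch only: inferred from the call sites above, not the actual torchtune code.
import re
import shutil
from pathlib import Path
from typing import List


def get_all_checkpoints_in_dir(output_dir: Path, pattern: str = r"^step_(\d+)") -> List[Path]:
    # Collect every directory in output_dir whose name matches the checkpoint pattern.
    regex = re.compile(pattern)
    return [d for d in output_dir.iterdir() if d.is_dir() and regex.match(d.name)]


def prune_surplus_checkpoints(checkpoints: List[Path], keep_last_n_checkpoints: int) -> None:
    # Sort by the numeric suffix (step or epoch) and remove everything but the newest N.
    def ckpt_index(p: Path) -> int:
        return int(p.name.split("_")[-1])

    for stale in sorted(checkpoints, key=ckpt_index)[:-keep_last_n_checkpoints]:
        shutil.rmtree(stale)
```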
