Skip to content

Commit

Permalink
Add helper functions for pruning old checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings committed Feb 10, 2025
1 parent 57f7cc1 commit c5093b7
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tests/torchtune/training/checkpointing/test_checkpointer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from torchtune.training.checkpointing._utils import (
check_outdir_not_in_ckptdir,
FormattedCheckpointFiles,
get_all_checkpoints_in_dir,
prune_surplus_checkpoints,
safe_torch_load,
update_state_dict_for_classifier,
)
Expand Down Expand Up @@ -271,3 +273,93 @@ def test_output_dir_ckpt_dir_few_levels_down(self):
match="The output directory cannot be the same as or a subdirectory of the checkpoint directory.",
):
check_outdir_not_in_ckptdir(ckpt_dir, out_dir)


class TestGetAllCheckpointsInDir:
"""Series of tests for the ``get_all_checkpoints_in_dir`` function."""

def test_get_all_ckpts_simple(self, tmp_dir):
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
assert len(all_ckpts) == 2
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]

def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir):
"""Test that we only return the checkpoints that match the pattern."""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "step_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_0]

def test_get_all_ckpts_override_pattern(self, tmp_dir):
"""Test that we can override the default pattern and it works."""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "step_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*")
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_1]

def test_get_all_ckpts_only_return_dirs(self, tmp_dir):
"""Test that even if a file matches the pattern, we only return directories."""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

file = tmp_dir / "epoch_1"
ckpt_dir_1.touch()

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_0]

def test_get_all_ckpts_non_unique(self, tmp_dir):
"""Test that we return all checkpoints, even if they have the same name."""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_0"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
assert len(all_ckpts) == 2
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]


class TestPruneSurplusCheckpoints:
"""Series of tests for the ``prune_surplus_checkpoints`` function."""

def test_prune_surplus_checkpoints_simple(self, tmp_dir):
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

prune_surplus_checkpoints(tmp_dir, 1)
remaining_ckpts = os.listdir(tmp_dir)
assert len(remaining_ckpts) == 1
assert remaining_ckpts == ["epoch_1"]

def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir):
"""Test that we raise an error if keep_last_n_checkpoints is not >= 1"""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"):
prune_surplus_checkpoints(tmp_dir, 0)
70 changes: 70 additions & 0 deletions torchtune/training/checkpointing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,73 @@ def check_outdir_not_in_ckptdir(ckpt_dir: Path, out_dir: Path) -> bool:
)

return True


def get_all_checkpoints_in_dir(
dir: Path, *, pattern: str = r"^epoch_(\d+)"
) -> List[Path]:
"""
Returns a list of all checkpoints in the given directory.
The pattern argument is a regular expression that matches the epoch number in the checkpoint filename.
The default pattern matches filenames of the form "epoch_{epoch_number}".
Args:
dir (Path): The directory containing the checkpoints.
pattern (str): A regular expression pattern to match the epoch number in the checkpoint filename.
Defaults to "epoch_(\d+)".
Example:
>>> dir = Path("/path/to/checkpoints")
>>> pattern = r"^epoch_(\d+)"
>>> get_all_checkpoints_in_dir(dir, pattern=pattern)
[PosixPath('/path/to/checkpoints/epoch_1'), PosixPath('/path/to/checkpoints/epoch_2'), ...]
Returns:
List[Path]: A list of Path objects representing the checkpoints..
"""
checkpoints = []
regex_to_match = re.compile(pattern)

# Iterate over the directory contents
for item in dir.iterdir():
if item.is_dir():
# Check if the directory name matches the pattern
match = regex_to_match.match(item.name)
if match:
checkpoints.append(item)

return checkpoints


def prune_surplus_checkpoints(
self, checkpoints: List[Path], keep_last_n_checkpoints: int = 1
) -> None:
"""
Prunes the surplus checkpoints in the given list of checkpoints.
The function will keep the latest `keep_last_n_checkpoints` checkpoints and delete the rest.
Args:
checkpoints (List[Path]): A list of Path objects representing the checkpoints.
keep_last_n_checkpoints (int): The number of checkpoints to keep. Defaults to 1.
Note:
Expects the format of the checkpoints to be "epoch_{epoch_number}" or "step_{step_number}". A higher number
indicates a more recent checkpoint. E.g. "epoch_1" is more recent than "epoch_0".
Example:
>>> checkpoints = [PosixPath('/path/to/checkpoints/epoch_1'), PosixPath('/path/to/checkpoints/epoch_2')]
>>> prune_surplus_checkpoints(checkpoints, keep_last_n_checkpoints=1)
>>> os.listdir('/path/to/checkpoints')
['epoch_2']
"""
if keep_last_n_checkpoints < 1:
raise ValueError("keep_last_n_checkpoints must be greater than or equal to 1.")

# Sort the checkpoints by their epoch or step number
checkpoints.sort(key=lambda x: int(x.name.split("_")[-1]), reverse=True)

# Delete the surplus checkpoints
for checkpoint in checkpoints[keep_last_n_checkpoints:]:
shutil.rmtree(checkpoint)

return

0 comments on commit c5093b7

Please sign in to comment.