From c5093b703e5125db6161d54b0871448dc975ab3f Mon Sep 17 00:00:00 2001 From: joecummings Date: Mon, 10 Feb 2025 12:52:29 -0800 Subject: [PATCH] Add helper functions for pruning old checkpoints --- .../checkpointing/test_checkpointer_utils.py | 92 +++++++++++++++++++ torchtune/training/checkpointing/_utils.py | 70 ++++++++++++++ 2 files changed, 162 insertions(+) diff --git a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py index 86f84d9a43..f847d21500 100644 --- a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py +++ b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py @@ -13,6 +13,8 @@ from torchtune.training.checkpointing._utils import ( check_outdir_not_in_ckptdir, FormattedCheckpointFiles, + get_all_checkpoints_in_dir, + prune_surplus_checkpoints, safe_torch_load, update_state_dict_for_classifier, ) @@ -271,3 +273,93 @@ def test_output_dir_ckpt_dir_few_levels_down(self): match="The output directory cannot be the same as or a subdirectory of the checkpoint directory.", ): check_outdir_not_in_ckptdir(ckpt_dir, out_dir) + + +class TestGetAllCheckpointsInDir: + """Series of tests for the ``get_all_checkpoints_in_dir`` function.""" + + def test_get_all_ckpts_simple(self, tmp_dir): + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "epoch_1" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + assert len(all_ckpts) == 2 + assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] + + def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir): + """Test that we only return the checkpoints that match the pattern.""" + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "step_1" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + assert len(all_ckpts) == 1 + assert all_ckpts == [ckpt_dir_0] + + def test_get_all_ckpts_override_pattern(self, tmp_dir): + """Test that we can override the default pattern and it works.""" + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "step_1" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*") + assert len(all_ckpts) == 1 + assert all_ckpts == [ckpt_dir_1] + + def test_get_all_ckpts_only_return_dirs(self, tmp_dir): + """Test that even if a file matches the pattern, we only return directories.""" + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + file = tmp_dir / "epoch_1" + ckpt_dir_1.touch() + + all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + assert len(all_ckpts) == 1 + assert all_ckpts == [ckpt_dir_0] + + def test_get_all_ckpts_non_unique(self, tmp_dir): + """Test that we return all checkpoints, even if they have the same name.""" + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "epoch_0" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + assert len(all_ckpts) == 2 + assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] + + +class TestPruneSurplusCheckpoints: + """Series of tests for the ``prune_surplus_checkpoints`` function.""" + + def test_prune_surplus_checkpoints_simple(self, tmp_dir): + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "epoch_1" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + prune_surplus_checkpoints(tmp_dir, 1) + remaining_ckpts = os.listdir(tmp_dir) + assert len(remaining_ckpts) == 1 + assert remaining_ckpts == ["epoch_1"] + + def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir): + """Test that we raise an error if keep_last_n_checkpoints is not >= 1""" + ckpt_dir_0 = tmp_dir / "epoch_0" + ckpt_dir_0.mkdir(parents=True, exist_ok=True) + + ckpt_dir_1 = tmp_dir / "epoch_1" + ckpt_dir_1.mkdir(parents=True, exist_ok=True) + + with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"): + prune_surplus_checkpoints(tmp_dir, 0) diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 618eebfb59..418525c0b2 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -594,3 +594,73 @@ def check_outdir_not_in_ckptdir(ckpt_dir: Path, out_dir: Path) -> bool: ) return True + + +def get_all_checkpoints_in_dir( + dir: Path, *, pattern: str = r"^epoch_(\d+)" +) -> List[Path]: + """ + Returns a list of all checkpoints in the given directory. + The pattern argument is a regular expression that matches the epoch number in the checkpoint filename. + The default pattern matches filenames of the form "epoch_{epoch_number}". + + Args: + dir (Path): The directory containing the checkpoints. + pattern (str): A regular expression pattern to match the epoch number in the checkpoint filename. + Defaults to "epoch_(\d+)". + + Example: + >>> dir = Path("/path/to/checkpoints") + >>> pattern = r"^epoch_(\d+)" + >>> get_all_checkpoints_in_dir(dir, pattern=pattern) + [PosixPath('/path/to/checkpoints/epoch_1'), PosixPath('/path/to/checkpoints/epoch_2'), ...] + + Returns: + List[Path]: A list of Path objects representing the checkpoints.. + """ + checkpoints = [] + regex_to_match = re.compile(pattern) + + # Iterate over the directory contents + for item in dir.iterdir(): + if item.is_dir(): + # Check if the directory name matches the pattern + match = regex_to_match.match(item.name) + if match: + checkpoints.append(item) + + return checkpoints + + +def prune_surplus_checkpoints( + self, checkpoints: List[Path], keep_last_n_checkpoints: int = 1 +) -> None: + """ + Prunes the surplus checkpoints in the given list of checkpoints. + The function will keep the latest `keep_last_n_checkpoints` checkpoints and delete the rest. + + Args: + checkpoints (List[Path]): A list of Path objects representing the checkpoints. + keep_last_n_checkpoints (int): The number of checkpoints to keep. Defaults to 1. + + Note: + Expects the format of the checkpoints to be "epoch_{epoch_number}" or "step_{step_number}". A higher number + indicates a more recent checkpoint. E.g. "epoch_1" is more recent than "epoch_0". + + Example: + >>> checkpoints = [PosixPath('/path/to/checkpoints/epoch_1'), PosixPath('/path/to/checkpoints/epoch_2')] + >>> prune_surplus_checkpoints(checkpoints, keep_last_n_checkpoints=1) + >>> os.listdir('/path/to/checkpoints') + ['epoch_2'] + """ + if keep_last_n_checkpoints < 1: + raise ValueError("keep_last_n_checkpoints must be greater than or equal to 1.") + + # Sort the checkpoints by their epoch or step number + checkpoints.sort(key=lambda x: int(x.name.split("_")[-1]), reverse=True) + + # Delete the surplus checkpoints + for checkpoint in checkpoints[keep_last_n_checkpoints:]: + shutil.rmtree(checkpoint) + + return