diff --git a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py index f847d21500..062b0c3ba0 100644 --- a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py +++ b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os from copy import deepcopy from pathlib import Path @@ -278,88 +279,85 @@ def test_output_dir_ckpt_dir_few_levels_down(self): class TestGetAllCheckpointsInDir: """Series of tests for the ``get_all_checkpoints_in_dir`` function.""" - def test_get_all_ckpts_simple(self, tmp_dir): - ckpt_dir_0 = tmp_dir / "epoch_0" + def test_get_all_ckpts_simple(self, tmpdir): + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - ckpt_dir_1 = tmp_dir / "epoch_1" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) + ckpt_dir_1 = tmpdir / "epoch_1" + ckpt_dir_1.mkdir() - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + all_ckpts = get_all_checkpoints_in_dir(tmpdir) assert len(all_ckpts) == 2 assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] - def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir): + def test_get_all_ckpts_with_pattern_that_matches_some(self, tmpdir): """Test that we only return the checkpoints that match the pattern.""" - ckpt_dir_0 = tmp_dir / "epoch_0" + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - ckpt_dir_1 = tmp_dir / "step_1" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) + ckpt_dir_1 = tmpdir / "step_1" + ckpt_dir_1.mkdir() - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + all_ckpts = get_all_checkpoints_in_dir(tmpdir) assert len(all_ckpts) == 1 assert all_ckpts == [ckpt_dir_0] - def test_get_all_ckpts_override_pattern(self, tmp_dir): + def test_get_all_ckpts_override_pattern(self, tmpdir): """Test that we can override the default pattern and it works.""" - ckpt_dir_0 = tmp_dir / "epoch_0" + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - ckpt_dir_1 = tmp_dir / "step_1" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) + ckpt_dir_1 = tmpdir / "step_1" + ckpt_dir_1.mkdir() - all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*") + all_ckpts = get_all_checkpoints_in_dir(tmpdir, pattern="step_*") assert len(all_ckpts) == 1 assert all_ckpts == [ckpt_dir_1] - def test_get_all_ckpts_only_return_dirs(self, tmp_dir): + def test_get_all_ckpts_only_return_dirs(self, tmpdir): """Test that even if a file matches the pattern, we only return directories.""" - ckpt_dir_0 = tmp_dir / "epoch_0" + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - file = tmp_dir / "epoch_1" - ckpt_dir_1.touch() + file = tmpdir / "epoch_1" + file.touch() - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) + all_ckpts = get_all_checkpoints_in_dir(tmpdir) assert len(all_ckpts) == 1 assert all_ckpts == [ckpt_dir_0] - def test_get_all_ckpts_non_unique(self, tmp_dir): - """Test that we return all checkpoints, even if they have the same name.""" - ckpt_dir_0 = tmp_dir / "epoch_0" - ckpt_dir_0.mkdir(parents=True, exist_ok=True) - - ckpt_dir_1 = tmp_dir / "epoch_0" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) - - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) - assert len(all_ckpts) == 2 - assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] - class TestPruneSurplusCheckpoints: """Series of tests for the ``prune_surplus_checkpoints`` function.""" - def test_prune_surplus_checkpoints_simple(self, tmp_dir): - ckpt_dir_0 = tmp_dir / "epoch_0" + def test_prune_surplus_checkpoints_simple(self, tmpdir): + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - ckpt_dir_1 = tmp_dir / "epoch_1" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) + ckpt_dir_1 = tmpdir / "epoch_1" + ckpt_dir_1.mkdir() - prune_surplus_checkpoints(tmp_dir, 1) - remaining_ckpts = os.listdir(tmp_dir) + prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 1) + remaining_ckpts = os.listdir(tmpdir) assert len(remaining_ckpts) == 1 assert remaining_ckpts == ["epoch_1"] - def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir): + def test_prune_surplus_checkpoints_keep_last_invalid(self, tmpdir): """Test that we raise an error if keep_last_n_checkpoints is not >= 1""" - ckpt_dir_0 = tmp_dir / "epoch_0" + tmpdir = Path(tmpdir) + ckpt_dir_0 = tmpdir / "epoch_0" ckpt_dir_0.mkdir(parents=True, exist_ok=True) - ckpt_dir_1 = tmp_dir / "epoch_1" - ckpt_dir_1.mkdir(parents=True, exist_ok=True) + ckpt_dir_1 = tmpdir / "epoch_1" + ckpt_dir_1.mkdir() - with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"): - prune_surplus_checkpoints(tmp_dir, 0) + with pytest.raises( + ValueError, + match="keep_last_n_checkpoints must be greater than or equal to 1", + ): + prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 0)