Skip to content

Commit

Permalink
Fix helper function tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings committed Feb 11, 2025
1 parent b0648fd commit 20f2acf
Showing 1 changed file with 42 additions and 44 deletions.
86 changes: 42 additions & 44 deletions tests/torchtune/training/checkpointing/test_checkpointer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from copy import deepcopy
from pathlib import Path

Expand Down Expand Up @@ -278,88 +279,85 @@ def test_output_dir_ckpt_dir_few_levels_down(self):
class TestGetAllCheckpointsInDir:
"""Series of tests for the ``get_all_checkpoints_in_dir`` function."""

def test_get_all_ckpts_simple(self, tmp_dir):
ckpt_dir_0 = tmp_dir / "epoch_0"
def test_get_all_ckpts_simple(self, tmpdir):
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
ckpt_dir_1 = tmpdir / "epoch_1"
ckpt_dir_1.mkdir()

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
assert len(all_ckpts) == 2
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]

def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir):
def test_get_all_ckpts_with_pattern_that_matches_some(self, tmpdir):
"""Test that we only return the checkpoints that match the pattern."""
ckpt_dir_0 = tmp_dir / "epoch_0"
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "step_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
ckpt_dir_1 = tmpdir / "step_1"
ckpt_dir_1.mkdir()

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_0]

def test_get_all_ckpts_override_pattern(self, tmp_dir):
def test_get_all_ckpts_override_pattern(self, tmpdir):
"""Test that we can override the default pattern and it works."""
ckpt_dir_0 = tmp_dir / "epoch_0"
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "step_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
ckpt_dir_1 = tmpdir / "step_1"
ckpt_dir_1.mkdir()

all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*")
all_ckpts = get_all_checkpoints_in_dir(tmpdir, pattern="step_*")
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_1]

def test_get_all_ckpts_only_return_dirs(self, tmp_dir):
def test_get_all_ckpts_only_return_dirs(self, tmpdir):
"""Test that even if a file matches the pattern, we only return directories."""
ckpt_dir_0 = tmp_dir / "epoch_0"
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

file = tmp_dir / "epoch_1"
ckpt_dir_1.touch()
file = tmpdir / "epoch_1"
file.touch()

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
assert len(all_ckpts) == 1
assert all_ckpts == [ckpt_dir_0]

def test_get_all_ckpts_non_unique(self, tmp_dir):
"""Test that we return all checkpoints, even if they have the same name."""
ckpt_dir_0 = tmp_dir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_0"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)

all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
assert len(all_ckpts) == 2
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]


class TestPruneSurplusCheckpoints:
"""Series of tests for the ``prune_surplus_checkpoints`` function."""

def test_prune_surplus_checkpoints_simple(self, tmp_dir):
ckpt_dir_0 = tmp_dir / "epoch_0"
def test_prune_surplus_checkpoints_simple(self, tmpdir):
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
ckpt_dir_1 = tmpdir / "epoch_1"
ckpt_dir_1.mkdir()

prune_surplus_checkpoints(tmp_dir, 1)
remaining_ckpts = os.listdir(tmp_dir)
prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 1)
remaining_ckpts = os.listdir(tmpdir)
assert len(remaining_ckpts) == 1
assert remaining_ckpts == ["epoch_1"]

def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir):
def test_prune_surplus_checkpoints_keep_last_invalid(self, tmpdir):
"""Test that we raise an error if keep_last_n_checkpoints is not >= 1"""
ckpt_dir_0 = tmp_dir / "epoch_0"
tmpdir = Path(tmpdir)
ckpt_dir_0 = tmpdir / "epoch_0"
ckpt_dir_0.mkdir(parents=True, exist_ok=True)

ckpt_dir_1 = tmp_dir / "epoch_1"
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
ckpt_dir_1 = tmpdir / "epoch_1"
ckpt_dir_1.mkdir()

with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"):
prune_surplus_checkpoints(tmp_dir, 0)
with pytest.raises(
ValueError,
match="keep_last_n_checkpoints must be greater than or equal to 1",
):
prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 0)

0 comments on commit 20f2acf

Please sign in to comment.