
Commit

updating tests
SalmanMohammadi committed Feb 13, 2025
1 parent 54a4dec commit 249f3b0
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 7 additions & 2 deletions tests/recipes/test_full_finetune_distributed.py
@@ -268,6 +268,9 @@ def test_loss_single_rank(
         [
             ("llama3/8B_full", "llama3", "tune", 1, 4, False),
         ],
+        [
+            ("llama3/8B_full", "llama3", "tune", 4, 1, True),
+        ],
     )
     @gpu_test(gpu_count=2)
     def test_training_state_on_resume(
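As rendered, the new GPU-heavy case appears as a second list handed to @pytest.mark.parametrize, but the decorator takes a single argvalues list, so the two cases presumably sit together in one list. Below is a minimal sketch of that combined decorator; the argnames string and the placeholder body are assumptions inferred from the tuple shape, not text shown in this hunk.

import pytest

# Hedged sketch: both parameter sets in one argvalues list, which is what
# pytest.mark.parametrize expects. The argnames below are assumed from the
# tuple shape and are not visible in the hunk above.
@pytest.mark.parametrize(
    "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
    [
        ("llama3/8B_full", "llama3", "tune", 1, 4, False),
        ("llama3/8B_full", "llama3", "tune", 4, 1, True),
    ],
)
def test_training_state_on_resume(
    config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd
):
    # Placeholder body; the real test drives the recipe through the tune CLI.
    assert isinstance(optim_in_bwd, bool)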
@@ -306,7 +309,8 @@ def test_training_state_on_resume(
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
             tokenizer.prompt_template=null \
-            clip_grad_norm=100 \
+            clip_grad_norm=f"{'100' if not optim_in_bwd else 'null'}" \
+            optimizer_in_bwd={optim_in_bwd} \
         """.split()

         model_config = MODEL_TEST_CONFIGS[model_type]
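The fixed clip_grad_norm=100 override becomes conditional on the new optim_in_bwd parameter, presumably because gradient clipping is not applied when the optimizer step is fused into the backward pass. One detail worth noting: inside the outer f-string that builds the command, the conditional would normally be a bare replacement field; wrapping it in an inner f"..." as rendered above leaves the literal characters f" in the generated argument. A small self-contained sketch of how the bare form resolves (names mirror the test, but the snippet is illustrative, not the recipe's code):

# Hedged sketch: how the conditional override resolves once the outer f-string
# is split into CLI-style key=value tokens. The trailing backslashes are
# removed by Python's line-continuation handling inside the string literal.
for optim_in_bwd in (False, True):
    overrides = f"""
        clip_grad_norm={'100' if not optim_in_bwd else 'null'} \
        optimizer_in_bwd={optim_in_bwd} \
    """.split()
    print(overrides)
    # optim_in_bwd=False -> ['clip_grad_norm=100', 'optimizer_in_bwd=False']
    # optim_in_bwd=True  -> ['clip_grad_norm=null', 'optimizer_in_bwd=True']

The same conditional appears again in the resumed-run command in the next hunk.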
@@ -338,7 +342,8 @@ def test_training_state_on_resume(
             tokenizer.prompt_template=null \
             resume_from_checkpoint=True \
             metric_logger.filename={log_file} \
-            clip_grad_norm=100 \
+            clip_grad_norm=f"{'100' if not optim_in_bwd else 'null'}" \
+            optimizer_in_bwd={optim_in_bwd} \
         """.split()

         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
8 changes: 7 additions & 1 deletion tests/recipes/test_full_finetune_single_device.py
@@ -137,7 +137,11 @@ def test_loss(
         )

     @pytest.mark.integration_test
-    def test_training_state_on_resume(self, tmpdir, monkeypatch):
+    @pytest.mark.parametrize(
+        "optimizer_in_bwd",
+        [True, False],
+    )
+    def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
         """Test whether the recipe state is correctly updated on resume. Since this
         is model agnostic, we should run this on the small model only. The test
         consists of three stages:
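The decorator added here runs the resume test twice, once per optimizer mode, and the parametrized flag is forwarded to the recipe as a plain config override in the two command hunks that follow. A minimal standalone sketch of that pattern; the body is a placeholder, not the real test:

import pytest

# Hedged sketch: parametrizing the boolean collects two test items, typically
# shown as test_training_state_on_resume[True] and [False], so both the
# optimizer-in-backward path and the regular optimizer path get their own
# resume coverage.
@pytest.mark.parametrize("optimizer_in_bwd", [True, False])
def test_training_state_on_resume(optimizer_in_bwd):
    # In the real test the flag becomes the optimizer_in_bwd=<bool> override
    # appended to the tune command, as in the hunks below.
    override = f"optimizer_in_bwd={optimizer_in_bwd}"
    assert override.endswith(str(optimizer_in_bwd))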
@@ -169,6 +173,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            optimizer_in_bwd={optimizer_in_bwd} \
         """.split()

         model_config = MODEL_TEST_CONFIGS["llama2"]
@@ -200,6 +205,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             tokenizer.prompt_template=null \
             resume_from_checkpoint=True \
             metric_logger.filename={log_file} \
+            optimizer_in_bwd={optimizer_in_bwd} \
         """.split()

         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
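The resumed run logs to metric_logger.filename={log_file}, so its losses can be checked against the values an uninterrupted run would produce. The docstring above is truncated here, so the sketch below only illustrates the general shape of such a check; the "loss:<value>" log format and the helper names are assumptions, not the test's actual utilities.

import torch

# Hedged sketch of a typical final assertion for a resume test: read the loss
# values written by the metric logger of the resumed run and compare them to
# the expected (uninterrupted) values.
def read_losses(log_file: str) -> list[float]:
    losses = []
    with open(log_file) as f:
        for line in f:
            for field in line.split("|"):
                field = field.strip()
                # Assumed field format, e.g. "loss:2.31"; not a documented contract.
                if field.startswith("loss:"):
                    losses.append(float(field.removeprefix("loss:").split()[0]))
    return losses

def check_resumed_losses(log_file: str, expected_losses: list[float]) -> None:
    resumed = torch.tensor(read_losses(log_file))
    expected = torch.tensor(expected_losses)
    torch.testing.assert_close(resumed, expected, rtol=1e-5, atol=1e-5)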
