Skip to content

Commit

Permalink
Fix tf iter unit test (#924)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 authored Jan 14, 2025
1 parent f93be34 commit 3405a6e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions axlearn/common/checkpointer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,11 +1223,13 @@ def test_no_save_input_iterator(self):
ckpt_dir = os.path.join(tmpdir, "tf_ckpt")
self.assertEqual(0, len(fs.listdir(tmpdir)))
# Test that when we don't save input iterator, tf dirs are not created.
async_save_tf_savables({}, executor=executor, dir=ckpt_dir)
fut = async_save_tf_savables({}, executor=executor, dir=ckpt_dir)
fut.result()
self.assertEqual([], fs.listdir(tmpdir))
# Test that dirs are created if we save.
ds = tf.data.Dataset.from_tensor_slices([])
async_save_tf_savables({"it": iter(ds)}, executor=executor, dir=ckpt_dir)
fut = async_save_tf_savables({"it": iter(ds)}, executor=executor, dir=ckpt_dir)
fut.result()
self.assertEqual(["tf_ckpt"], fs.listdir(tmpdir))


Expand Down

0 comments on commit 3405a6e

Please sign in to comment.