Skip to content

Commit

Permalink
fix traindiffs
Browse files Browse the repository at this point in the history
  • Loading branch information
chandramouli-sastry committed Feb 21, 2024
1 parent 25ab4ef commit b529230
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_traindiffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_workload(self, workload):
pyt_logs = '/tmp/pyt_log.pkl'
try:
run(
f'python3 -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}'
f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python3 -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}'
f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
shell=True,
stdout=DEVNULL,
Expand All @@ -60,7 +60,7 @@ def test_workload(self, workload):
print("Error:", e)
try:
run(
f'torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}'
f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}'
f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
shell=True,
stdout=DEVNULL,
Expand Down

0 comments on commit b529230

Please sign in to comment.