diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 5fb15fda6..80c0ed770 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -1,5 +1,4 @@
 """ImageNet workload implemented in Jax."""
-import copy
 from typing import Dict, Optional, Tuple
 
 from flax import jax_utils
diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
index c2649bdb6..6a85c2196 100644
--- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
+++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
@@ -28,22 +28,20 @@ def test_forward_pass(self):
     first_input_batch = jax.random.normal(data_rngs[0], shape=input_shape)
     expected_logits_shape = (jax.local_device_count(), batch_size, 1000)
 
-    # static_broadcasted_argnums=(3, 7) will recompile each time we call it in
-    # this file because we call it with a different combination of those two
+    # static_broadcasted_argnums=(3, 5) will recompile each time we call it in
+    # this function because we call it with a different combination of those two
     # args each time. Can't call with kwargs.
     pmapped_model_fn = jax.pmap(
         workload.model_fn,
         axis_name='batch',
-        in_axes=(0, 0, 0, None, None, None, None, None),
-        static_broadcasted_argnums=(3, 7))
+        in_axes=(0, 0, 0, None, None, None),
+        static_broadcasted_argnums=(3, 5))
     logits, updated_batch_stats = pmapped_model_fn(
         model_params,
         {'inputs': first_input_batch},
         batch_stats,
         spec.ForwardPassMode.TRAIN,
         rng,
-        None,
-        None,
         True)
     self.assertEqual(logits.shape, expected_logits_shape)
     # Test that batch stats are updated.
@@ -58,8 +56,6 @@ def test_forward_pass(self):
         updated_batch_stats,
         spec.ForwardPassMode.TRAIN,
         rng,
-        None,
-        None,
         False)
     self.assertEqual(
         _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0)
@@ -71,8 +67,6 @@ def test_forward_pass(self):
         batch_stats,
         spec.ForwardPassMode.EVAL,
         rng,
-        None,
-        None,
         False)
     self.assertEqual(logits.shape, expected_logits_shape)
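
Note on the static_broadcasted_argnums comment in the patch above: a minimal, self-contained sketch (illustrative only, not part of this patch; the scale function and factor argument are made up) of why jax.pmap re-traces whenever a statically broadcast argument takes a new value, which is what happens in this test when it switches between ForwardPassMode.TRAIN/EVAL and update_batch_norm True/False.

import jax
import jax.numpy as jnp

def scale(x, factor):
  # `factor` is a compile-time constant: each distinct value is baked into a
  # separately compiled executable.
  return x * factor

pmapped_scale = jax.pmap(
    scale,
    axis_name='batch',
    in_axes=(0, None),
    static_broadcasted_argnums=(1,))

x = jnp.ones((jax.local_device_count(), 4))
pmapped_scale(x, 2)  # first call with factor=2: traces and compiles
pmapped_scale(x, 2)  # same static value: reuses the cached executable
pmapped_scale(x, 3)  # new static value: recompiles, analogous to TRAIN vs. EVAL above

Passing the static arguments positionally (no kwargs) mirrors the constraint noted in the test's comment: pmap's static_broadcasted_argnums refers to positional argument indices.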