Skip to content

Commit

Permalink
fixing lint and test
Browse files Browse the repository at this point in the history
  • Loading branch information
znado committed Oct 24, 2022
1 parent c6e9aae commit 4109e8a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""ImageNet workload implemented in Jax."""
import copy
from typing import Dict, Optional, Tuple

from flax import jax_utils
Expand Down
14 changes: 4 additions & 10 deletions tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,20 @@ def test_forward_pass(self):
first_input_batch = jax.random.normal(data_rngs[0], shape=input_shape)
expected_logits_shape = (jax.local_device_count(), batch_size, 1000)

# static_broadcasted_argnums=(3, 7) will recompile each time we call it in
# this file because we call it with a different combination of those two
# static_broadcasted_argnums=(3, 5) will recompile each time we call it in
# this function because we call it with a different combination of those two
# args each time. Can't call with kwargs.
pmapped_model_fn = jax.pmap(
workload.model_fn,
axis_name='batch',
in_axes=(0, 0, 0, None, None, None, None, None),
static_broadcasted_argnums=(3, 7))
in_axes=(0, 0, 0, None, None, None),
static_broadcasted_argnums=(3, 5))
logits, updated_batch_stats = pmapped_model_fn(
model_params,
{'inputs': first_input_batch},
batch_stats,
spec.ForwardPassMode.TRAIN,
rng,
None,
None,
True)
self.assertEqual(logits.shape, expected_logits_shape)
# Test that batch stats are updated.
Expand All @@ -58,8 +56,6 @@ def test_forward_pass(self):
updated_batch_stats,
spec.ForwardPassMode.TRAIN,
rng,
None,
None,
False)
self.assertEqual(
_pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0)
Expand All @@ -71,8 +67,6 @@ def test_forward_pass(self):
batch_stats,
spec.ForwardPassMode.EVAL,
rng,
None,
None,
False)
self.assertEqual(logits.shape, expected_logits_shape)

Expand Down

0 comments on commit 4109e8a

Please sign in to comment.