in jax resnet50 renaming train to update_batch_norm because that's all it's used for. adding jax resnet50 test.
znado committed Mar 2, 2022
1 parent 11084d3 commit 8016526
Showing 7 changed files with 96 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linting.yml
@@ -19,7 +19,7 @@ jobs:
       run: |
         pylint algorithmic_efficiency
         pylint baselines
-        pylint tests
+        pylint submission_runner_test.py
   isort:
     runs-on: ubuntu-latest
@@ -51,4 +51,4 @@ jobs:
         pip install yapf
     - name: Run yapf
       run: |
-        yapf . --diff --recursive
+        yapf . --diff --recursive
9 changes: 7 additions & 2 deletions algorithmic_efficiency/spec.py
@@ -138,8 +138,13 @@ def num_train_examples(self):
 
   @property
   @abc.abstractmethod
-  def num_eval_examples(self):
-    """The size of the evaluation set."""
+  def num_eval_train_examples(self):
+    """The number of training examples to evaluate metrics on."""
+
+  @property
+  @abc.abstractmethod
+  def num_validation_examples(self):
+    """The size of the validation set."""
 
   @property
   @abc.abstractmethod
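For illustration, a minimal sketch of how a concrete workload might implement the newly split properties; the class name and example sizes here are hypothetical, not part of this commit:

from algorithmic_efficiency import spec

class ExampleWorkload(spec.Workload):  # hypothetical; other abstract methods omitted

  @property
  def num_eval_train_examples(self):
    # How many training examples eval metrics are computed over.
    return 50000  # illustrative value

  @property
  def num_validation_examples(self):
    # Size of the held-out validation split.
    return 50000  # illustrative value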
12 changes: 7 additions & 5 deletions algorithmic_efficiency/workloads/imagenet/imagenet_jax/models.py
@@ -3,7 +3,7 @@
 """Flax implementation of ResNet V1."""
 
 from functools import partial
-from typing import Any, Callable, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 from flax import linen as nn
 import jax.numpy as jnp
@@ -76,20 +76,22 @@ class ResNet(nn.Module):
   act: Callable = nn.relu
 
   @nn.compact
-  def __call__(self, x, train: bool = True):
+  def __call__(
+      self,
+      x,
+      update_batch_norm: bool = True):
     conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
     norm = partial(
         nn.BatchNorm,
-        use_running_average=not train,
+        use_running_average=not update_batch_norm,
         momentum=0.9,
         epsilon=1e-5,
         dtype=self.dtype)
 
     x = conv(
         self.num_filters, (7, 7), (2, 2),
         padding=[(3, 3), (3, 3)],
-        name='conv_init')(
-            x)
+        name='conv_init')(x)
     x = norm(name='bn_init')(x)
     x = nn.relu(x)
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
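The renamed flag is simply the inverse of Flax BatchNorm's use_running_average. A minimal sketch of both call modes, assuming model is an already-constructed ResNet (its constructor arguments are outside this hunk):

import jax
import jax.numpy as jnp

model = ...  # assumed: a constructed ResNet instance; fields not shown in this diff
rng = jax.random.PRNGKey(0)
images = jnp.ones((8, 224, 224, 3))
variables = model.init(rng, images)  # holds 'params' and 'batch_stats'
# Training: batch statistics are recomputed, so they must be mutable.
logits, new_state = model.apply(
    variables, images, update_batch_norm=True, mutable=['batch_stats'])
# Inference: the stored running averages are used and nothing is mutated.
logits = model.apply(variables, images, update_batch_norm=False, mutable=False)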
14 changes: 5 additions & 9 deletions algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
@@ -13,7 +13,6 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-import numpy as np
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.imagenet.imagenet_jax import \
@@ -129,19 +128,18 @@ def model_fn(
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
-    train = mode == spec.ForwardPassMode.TRAIN
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
           variables,
-          jax.numpy.squeeze(input_batch['image']),
-          train=train,
+          input_batch,
+          update_batch_norm=update_batch_norm,
           mutable=['batch_stats'])
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
-          jax.numpy.squeeze(input_batch['image']),
-          train=train,
+          input_batch,
+          update_batch_norm=update_batch_norm,
           mutable=False)
       return logits, None
 
@@ -172,7 +170,7 @@ def _eval_model_on_split(self,
                           rng: spec.RandomState,
                           data_dir: str):
     eval_per_core_batch_size = 256
-    eval_total_batch_size = eval_per_core_batch_size * jax.num_devices()
+    eval_total_batch_size = eval_per_core_batch_size * jax.local_device_count()
     if split == 'train':
       num_examples = self.num_eval_train_examples
     else:
@@ -200,8 +198,6 @@ def _eval_model_on_split(self,
         eval_metrics[metric_name] = 0.0
       eval_metrics[metric_name] += metric_value
 
-    # eval_metrics = jax.device_get(eval_metrics)
-    # eval_metrics = jax.tree_multimap(lambda *x: np.stack(x), *eval_metrics)
     eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
     return eval_metrics
 
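Taken together, the two branches of model_fn above give these call patterns (a sketch; the argument order matches the pmapped call in the new test below, and the variable names are placeholders):

# Train step: batch norm statistics advance and are returned as new state.
logits, new_model_state = workload.model_fn(
    params, input_batch, model_state,
    spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True)
# Eval step: statistics are frozen and the auxiliary state is None.
logits, _ = workload.model_fn(
    params, input_batch, model_state,
    spec.ForwardPassMode.EVAL, rng, update_batch_norm=False)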
71 changes: 71 additions & 0 deletions (new file)
@@ -0,0 +1,71 @@
"""Tests for imagenet_jax/workload.py."""

from absl.testing import absltest
import jax
import jax.numpy as jnp

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.imagenet.imagenet_jax.workload import ImagenetJaxWorkload


def _pytree_total_diff(pytree_a, pytree_b):
pytree_diff = jax.tree_map(
lambda a, b: jnp.sum(a - b), pytree_a, pytree_b)
pytree_diff = jax.tree_leaves(pytree_diff)
return jnp.sum(jnp.array(pytree_diff))


class ModelsTest(absltest.TestCase):
"""Tests for imagenet_jax/workload.py."""

def test_forward_pass(self):
batch_size = 11
rng = jax.random.PRNGKey(0)
rng, model_init_rng, *data_rngs = jax.random.split(rng, 4)
workload = ImagenetJaxWorkload()
model_params, batch_stats = workload.init_model_fn(model_init_rng)
input_shape = (jax.local_device_count(), batch_size, 224, 224, 3)
first_input_batch = jax.random.normal(data_rngs[0], shape=input_shape)
expected_logits_shape = (jax.local_device_count(), batch_size, 1000)

pmapped_model_fn = jax.pmap(
workload.model_fn,
axis_name='batch',
in_axes=(0, 0, 0, None, None, None),
static_broadcasted_argnums=(3, 5))
logits, updated_batch_stats = pmapped_model_fn(
model_params,
first_input_batch,
batch_stats,
spec.ForwardPassMode.TRAIN,
rng,
True)
self.assertEqual(logits.shape, expected_logits_shape)
# Test that batch stats are updated.
self.assertNotEqual(
_pytree_total_diff(batch_stats, updated_batch_stats), 0.0)

second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape)
# Test that batch stats are not updated when we say so.
_, same_batch_stats = pmapped_model_fn(
model_params,
second_input_batch,
batch_stats,
spec.ForwardPassMode.TRAIN,
rng,
False)
self.assertIsNone(same_batch_stats)

# Test eval model.
logits, _ = pmapped_model_fn(
model_params,
second_input_batch,
batch_stats,
spec.ForwardPassMode.EVAL,
rng,
False)
self.assertEqual(logits.shape, expected_logits_shape)


if __name__ == '__main__':
absltest.main()
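Note that the pmap maps over the leading axis of model_params and batch_stats (in_axes=(0, 0, 0, None, None, None)), so init_model_fn is expected to return pytrees already replicated across local devices, while mode and update_batch_norm are hashable Python values broadcast as static arguments. A sketch of one way to produce that replicated layout, assuming unreplicated trees (flax.jax_utils is one option; it is not used in this commit):

from flax import jax_utils

# Adds a leading axis of size jax.local_device_count() to every leaf.
replicated_params = jax_utils.replicate(unreplicated_params)
replicated_stats = jax_utils.replicate(unreplicated_batch_stats)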
4 changes: 3 additions & 1 deletion submission_runner.py
@@ -184,7 +184,9 @@ def train_once(workload: spec.Workload,
     step_rng = prng.fold_in(rng, global_step)
     data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
     start_time = time.time()
-    selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection(
+    (selected_train_input_batch,
+     selected_train_label_batch,
+     selected_train_mask_batch) = data_selection(
         workload,
         input_queue,
         optimizer_state,
3 changes: 1 addition & 2 deletions submission_runner_test.py
@@ -2,8 +2,7 @@
 
 import os
 
-from algorithmic_efficiency.submission_runner import \
-    _convert_filepath_to_module
+from submission_runner import _convert_filepath_to_module
 
 
 def test_convert_filepath_to_module():
