From 8016526158c90b4acfe0df4d0a5eb0d95c3fa8d3 Mon Sep 17 00:00:00 2001
From: Zack Nado
Date: Wed, 2 Mar 2022 01:29:23 -0500
Subject: [PATCH] Rename the Jax ResNet50 model's `train` argument to
 `update_batch_norm`, since controlling batch norm is all it is used for, and
 add a Jax ResNet50 workload test.

---
 .github/workflows/linting.yml                 |  4 +-
 algorithmic_efficiency/spec.py                |  9 ++-
 .../workloads/imagenet/imagenet_jax/models.py | 12 ++--
 .../imagenet/imagenet_jax/workload.py         | 14 ++--
 .../imagenet/imagenet_jax/workload_test.py    | 71 +++++++++++++++++++
 submission_runner.py                          |  4 +-
 ...ion_runner.py => submission_runner_test.py |  3 +-
 7 files changed, 96 insertions(+), 21 deletions(-)
 create mode 100644 algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload_test.py
 rename tests/test_submission_runner.py => submission_runner_test.py (81%)

diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index b3395882a..399614f58 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -19,7 +19,7 @@ jobs:
       run: |
         pylint algorithmic_efficiency
         pylint baselines
-        pylint tests
+        pylint submission_runner_test.py
 
   isort:
     runs-on: ubuntu-latest
@@ -51,4 +51,4 @@ jobs:
         pip install yapf
     - name: Run yapf
       run: |
-        yapf . --diff --recursive
\ No newline at end of file
+        yapf . --diff --recursive
diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py
index 3a1b2f33b..5bd379fd8 100644
--- a/algorithmic_efficiency/spec.py
+++ b/algorithmic_efficiency/spec.py
@@ -138,8 +138,13 @@ def num_train_examples(self):
 
   @property
   @abc.abstractmethod
-  def num_eval_examples(self):
-    """The size of the evaluation set."""
+  def num_eval_train_examples(self):
+    """The number of training examples to evaluate metrics on."""
+
+  @property
+  @abc.abstractmethod
+  def num_validation_examples(self):
+    """The size of the validation set."""
 
   @property
   @abc.abstractmethod
diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/models.py
index 92318be27..357e27290 100644
--- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/models.py
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/models.py
@@ -3,7 +3,7 @@
 """Flax implementation of ResNet V1."""
 
 from functools import partial
-from typing import Any, Callable, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 from flax import linen as nn
 import jax.numpy as jnp
@@ -76,11 +76,14 @@ class ResNet(nn.Module):
   act: Callable = nn.relu
 
   @nn.compact
-  def __call__(self, x, train: bool = True):
+  def __call__(
+      self,
+      x,
+      update_batch_norm: bool = True):
     conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
     norm = partial(
         nn.BatchNorm,
-        use_running_average=not train,
+        use_running_average=not update_batch_norm,
         momentum=0.9,
         epsilon=1e-5,
         dtype=self.dtype)
@@ -88,8 +91,7 @@ def __call__(self, x, train: bool = True):
     x = conv(
         self.num_filters, (7, 7), (2, 2),
         padding=[(3, 3), (3, 3)],
-        name='conv_init')(
-            x)
+        name='conv_init')(x)
     x = norm(name='bn_init')(x)
     x = nn.relu(x)
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
index c8ebdc7e4..6b1022025 100644
--- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
@@ -13,7 +13,6 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-import numpy as np
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.imagenet.imagenet_jax import \
@@ -129,19 +128,18 @@ def model_fn(
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
-    train = mode == spec.ForwardPassMode.TRAIN
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
           variables,
-          jax.numpy.squeeze(input_batch['image']),
-          train=train,
+          input_batch,
+          update_batch_norm=update_batch_norm,
           mutable=['batch_stats'])
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
-          jax.numpy.squeeze(input_batch['image']),
-          train=train,
+          input_batch,
+          update_batch_norm=update_batch_norm,
           mutable=False)
       return logits, None
@@ -172,7 +170,7 @@ def _eval_model_on_split(self,
                            rng: spec.RandomState,
                            data_dir: str):
     eval_per_core_batch_size = 256
-    eval_total_batch_size = eval_per_core_batch_size * jax.num_devices()
+    eval_total_batch_size = eval_per_core_batch_size * jax.local_device_count()
     if split == 'train':
       num_examples = self.num_eval_train_examples
     else:
@@ -200,8 +198,6 @@ def _eval_model_on_split(self,
           eval_metrics[metric_name] = 0.0
         eval_metrics[metric_name] += metric_value
 
-    # eval_metrics = jax.device_get(eval_metrics)
-    # eval_metrics = jax.tree_multimap(lambda *x: np.stack(x), *eval_metrics)
     eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
     return eval_metrics
diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload_test.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload_test.py
new file mode 100644
index 000000000..deb846901
--- /dev/null
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload_test.py
@@ -0,0 +1,71 @@
+"""Tests for imagenet_jax/workload.py."""
+
+from absl.testing import absltest
+import jax
+import jax.numpy as jnp
+
+from algorithmic_efficiency import spec
+from algorithmic_efficiency.workloads.imagenet.imagenet_jax.workload import ImagenetJaxWorkload
+
+
+def _pytree_total_diff(pytree_a, pytree_b):
+  pytree_diff = jax.tree_map(
+      lambda a, b: jnp.sum(a - b), pytree_a, pytree_b)
+  pytree_diff = jax.tree_leaves(pytree_diff)
+  return jnp.sum(jnp.array(pytree_diff))
+
+
+class ModelsTest(absltest.TestCase):
+  """Tests for imagenet_jax/workload.py."""
+
+  def test_forward_pass(self):
+    batch_size = 11
+    rng = jax.random.PRNGKey(0)
+    rng, model_init_rng, *data_rngs = jax.random.split(rng, 4)
+    workload = ImagenetJaxWorkload()
+    model_params, batch_stats = workload.init_model_fn(model_init_rng)
+    input_shape = (jax.local_device_count(), batch_size, 224, 224, 3)
+    first_input_batch = jax.random.normal(data_rngs[0], shape=input_shape)
+    expected_logits_shape = (jax.local_device_count(), batch_size, 1000)
+
+    pmapped_model_fn = jax.pmap(
+        workload.model_fn,
+        axis_name='batch',
+        in_axes=(0, 0, 0, None, None, None),
+        static_broadcasted_argnums=(3, 5))
+    logits, updated_batch_stats = pmapped_model_fn(
+        model_params,
+        first_input_batch,
+        batch_stats,
+        spec.ForwardPassMode.TRAIN,
+        rng,
+        True)
+    self.assertEqual(logits.shape, expected_logits_shape)
+    # Test that batch stats are updated.
+    self.assertNotEqual(
+        _pytree_total_diff(batch_stats, updated_batch_stats), 0.0)
+
+    second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape)
+    # Test that batch stats are not updated when update_batch_norm is False.
+    _, same_batch_stats = pmapped_model_fn(
+        model_params,
+        second_input_batch,
+        batch_stats,
+        spec.ForwardPassMode.TRAIN,
+        rng,
+        False)
+    self.assertIsNone(same_batch_stats)
+
+    # Test eval model.
+    logits, _ = pmapped_model_fn(
+        model_params,
+        second_input_batch,
+        batch_stats,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        False)
+    self.assertEqual(logits.shape, expected_logits_shape)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/submission_runner.py b/submission_runner.py
index 882b47480..3df78de22 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -184,7 +184,9 @@ def train_once(workload: spec.Workload,
     step_rng = prng.fold_in(rng, global_step)
     data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
     start_time = time.time()
-    selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection(
+    (selected_train_input_batch,
+     selected_train_label_batch,
+     selected_train_mask_batch) = data_selection(
         workload,
         input_queue,
         optimizer_state,
diff --git a/tests/test_submission_runner.py b/submission_runner_test.py
similarity index 81%
rename from tests/test_submission_runner.py
rename to submission_runner_test.py
index 24fddb44a..a79d0499d 100644
--- a/tests/test_submission_runner.py
+++ b/submission_runner_test.py
@@ -2,8 +2,7 @@
 
 import os
 
-from algorithmic_efficiency.submission_runner import \
-    _convert_filepath_to_module
+from submission_runner import _convert_filepath_to_module
 
 
 def test_convert_filepath_to_module():
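
For context on the core rename (this note and the sketch below are not part of the patch): the model's flag is only ever used to set `nn.BatchNorm`'s `use_running_average` switch, which is why `update_batch_norm` is the more accurate name. What follows is a minimal sketch using a hypothetical `TinyBatchNormModel` and only the public Flax/JAX API, showing the two call paths the flag selects between, mirroring the pattern in `models.py` and the `model_fn` in `workload.py` above:

    from functools import partial

    from flax import linen as nn
    import jax
    import jax.numpy as jnp


    class TinyBatchNormModel(nn.Module):
      """Hypothetical stand-in for the ResNet's BatchNorm wiring."""

      @nn.compact
      def __call__(self, x, update_batch_norm: bool = True):
        # Same pattern as models.py: the flag only toggles whether BatchNorm
        # normalizes with batch statistics (True) or with its stored running
        # averages (False).
        norm = partial(
            nn.BatchNorm,
            use_running_average=not update_batch_norm,
            momentum=0.9,
            epsilon=1e-5)
        x = nn.Dense(4)(x)
        x = norm(name='bn')(x)
        return nn.relu(x)


    model = TinyBatchNormModel()
    x = jnp.ones((8, 4))
    variables = model.init(jax.random.PRNGKey(0), x)

    # Training-style call, as in workload.py's update_batch_norm=True branch:
    # batch_stats is marked mutable, and the updated stats are returned.
    _, new_state = model.apply(
        variables, x, update_batch_norm=True, mutable=['batch_stats'])

    # Eval-style call, as in the update_batch_norm=False branch: the running
    # averages are read but nothing is mutated, so only the output is returned.
    y = model.apply(variables, x, update_batch_norm=False, mutable=False)

This split is also why the new test can pass `update_batch_norm` as a `static_broadcasted_argnums` entry to `jax.pmap`: it is a plain Python bool that changes the traced computation, not an array input.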