diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py
index 95024215e..4096f3d08 100644
--- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py
@@ -262,7 +262,7 @@ def decode_example(example):
   if num_batches is not None:
     ds = ds.take(num_batches)
 
-  if not train or repeat_final_dataset:
+  if repeat_final_dataset:
     ds = ds.repeat()
 
   ds = ds.prefetch(10)
diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
index 56f751767..f360ac273 100644
--- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
@@ -36,7 +36,7 @@ def _build_dataset(self,
                      cache: Optional[bool] = None,
                      repeat_final_dataset: Optional[bool] = None,
                      num_batches: Optional[int] = None):
-    if batch_size % jax.local_device_count() > 0:
+    if batch_size % jax.local_device_count() != 0:
       raise ValueError('Batch size must be divisible by the number of devices')
     ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir)
     ds_builder.download_and_prepare()
diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py
index c5cc0371a..96bfb642f 100644
--- a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py
@@ -2,6 +2,7 @@
 import contextlib
+import math
 import os
 from typing import Tuple
 
 import torch
 
@@ -52,42 +53,6 @@ def model_params_types(self):
       self._param_types = param_utils.pytorch_param_types(self._param_shapes)
     return self._param_types
 
-  def _eval_model_on_split(self,
-                           split: str,
-                           num_examples: int,
-                           global_batch_size: int,
-                           params: spec.ParameterContainer,
-                           model_state: spec.ModelAuxiliaryState,
-                           rng: spec.RandomState,
-                           data_dir: str):
-    """Run a full evaluation of the model."""
-    data_rng, model_rng = prng.split(rng, 2)
-    if split not in self._eval_iters:
-      self._eval_iters[split] = self.build_input_queue(
-          data_rng, split, data_dir, global_batch_size=global_batch_size)
-
-    total_metrics = {
-        'accuracy': 0.,
-        'loss': 0.,
-    }
-    num_data = 0
-    for batch in self._eval_iters[split]:
-      images = batch['inputs'].float().to(DEVICE)
-      labels = batch['targets'].to(DEVICE)
-      logits, _ = self.model_fn(
-          params,
-          images,
-          model_state,
-          spec.ForwardPassMode.EVAL,
-          model_rng,
-          update_batch_norm=False)
-      batch_metrics = self._eval_metric(logits, labels)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
-      num_data += batch_metrics['num_data']
-    return {k: float(v / num_data) for k, v in total_metrics.items()}
-
   def _build_dataset(self,
                      data_rng: spec.RandomState,
                      split: str,
@@ -137,8 +102,7 @@ def _build_dataset(self,
         pin_memory=True,
         drop_last=is_train)
 
-    if is_train:
-      dataloader = cycle(dataloader)
+    dataloader = cycle(dataloader)
 
     return dataloader
 
@@ -220,3 +184,42 @@ def _eval_metric(self, logits, labels):
     loss = self.loss_fn(labels, logits).sum().item()
     num_data = len(logits)
     return {'accuracy': accuracy, 'loss': loss, 'num_data': num_data}
+
+  def _eval_model_on_split(self,
+                           split: str,
+                           num_examples: int,
+                           global_batch_size: int,
+                           params: spec.ParameterContainer,
+                           model_state: spec.ModelAuxiliaryState,
+                           rng: spec.RandomState,
+                           data_dir: str):
+    """Run a full evaluation of the model."""
+    data_rng, model_rng = prng.split(rng, 2)
+    if split not in self._eval_iters:
+      # These iterators repeat indefinitely.
+      self._eval_iters[split] = self.build_input_queue(
+          data_rng, split, data_dir, global_batch_size=global_batch_size)
+
+    total_metrics = {
+        'accuracy': 0.,
+        'loss': 0.,
+    }
+    num_data = 0
+    num_batches = int(math.ceil(num_examples / global_batch_size))
+    for _ in range(num_batches):
+      batch = next(self._eval_iters[split])
+      images = batch['inputs'].float().to(DEVICE)
+      labels = batch['targets'].to(DEVICE)
+      logits, _ = self.model_fn(
+          params,
+          images,
+          model_state,
+          spec.ForwardPassMode.EVAL,
+          model_rng,
+          update_batch_norm=False)
+      batch_metrics = self._eval_metric(logits, labels)
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
+      num_data += batch_metrics['num_data']
+    return {k: float(v / num_data) for k, v in total_metrics.items()}
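Note: the PyTorch workload above now wraps the DataLoader in cycle(...) for every split, so _eval_model_on_split can simply call next() a fixed number of times. The repo's cycle helper is not part of this patch; below is a minimal sketch of the behavior the eval loop assumes.

def cycle(iterable):
  # Re-iterate the underlying iterable (e.g. a DataLoader) forever.
  # itertools.cycle would instead replay cached batches, which for a
  # shuffling train loader would repeat the first epoch's ordering.
  while True:
    for batch in iterable:
      yield batch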
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index dd6ab78c6..dadeb34f2 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -1,5 +1,6 @@
 """MNIST workload parent class."""
 import itertools
+import math
 from typing import Dict, Tuple
 
 from algorithmic_efficiency import spec
@@ -99,7 +100,7 @@ def _eval_model_on_split(self,
         'loss': 0.,
     }
     num_data = 0
-    num_batches = num_examples // global_batch_size
+    num_batches = int(math.ceil(num_examples / global_batch_size))
     for bi, batch in enumerate(self._eval_iters[split]):
-      if bi > num_batches:
+      if bi >= num_batches:
         break
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 4b2723377..8dc0074ec 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -1,6 +1,7 @@
 """WMT workload implemented in Jax."""
 import collections
 import functools
+import math
 from typing import Dict, Optional, Tuple
 
 from absl import logging
diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algorithmic_efficiency/workloads/wmt/workload.py
index cb6afc00e..083b91e13 100644
--- a/algorithmic_efficiency/workloads/wmt/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/workload.py
@@ -1,3 +1,4 @@
 """WMT workload parent class."""
+import math
 from typing import Dict
 
@@ -105,7 +106,7 @@ def _eval_model_on_split(self,
                            rng: spec.RandomState,
                            data_dir: str) -> Dict[str, float]:
     """Run a full evaluation of the model."""
-    num_batches = num_examples // global_batch_size
+    num_batches = int(math.ceil(num_examples / global_batch_size))
     if split not in self._eval_iters:
       # These iterators will repeat indefinitely.
       self._eval_iters[split] = self.build_input_queue(
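All three eval paths now share one pattern: draw exactly ceil(num_examples / global_batch_size) batches from an iterator that repeats indefinitely, instead of iterating a finite dataset with floor division. A self-contained sketch of that counting logic (the helper name is illustrative, not from the repo):

import itertools
import math


def take_eval_batches(eval_iter, num_examples, global_batch_size):
  # Just enough batches to cover the whole split; floor division would
  # silently skip the final partial batch.
  num_batches = int(math.ceil(num_examples / global_batch_size))
  return itertools.islice(eval_iter, num_batches)

For instance, a 10,000-example eval split at global batch size 1024 needs ceil(10000 / 1024) = 10 batches; the old floor division gave 9 and dropped the last 784 examples.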
diff --git a/tests/reference_submission_tests.py b/tests/reference_submission_tests.py
index edaaa825e..9ecd4a0a3 100644
--- a/tests/reference_submission_tests.py
+++ b/tests/reference_submission_tests.py
@@ -43,9 +43,9 @@
 }
 
 
-def _make_fake_image_batch(framework, batch_shape, data_shape, num_classes):
-  examples = np.random.normal(size=(*batch_shape,
-                                    *data_shape)).astype(np.float32)
+def _make_fake_image_batch(batch_shape, data_shape, num_classes):
+  examples = np.random.normal(
+      size=(*batch_shape, *data_shape)).astype(np.float32)
   labels = np.random.randint(0, num_classes, size=batch_shape)
   masks = np.ones((*batch_shape, *data_shape), dtype=np.float32)
   return {'inputs': examples, 'targets': labels, 'weights': masks}
@@ -95,20 +95,20 @@ def build_input_queue(self, *args, **kwargs):
 
       if workload_name == 'mnist':
         fake_batch = _make_fake_image_batch(
-            framework, batch_shape, data_shape=(28, 28, 1), num_classes=10)
+            batch_shape, data_shape=(28, 28, 1), num_classes=10)
       elif workload_name == 'imagenet':
         if framework == 'jax':
           data_shape = (224, 224, 3)
         else:
           data_shape = (3, 224, 224)
         fake_batch = _make_fake_image_batch(
-            framework, batch_shape, data_shape=data_shape, num_classes=1000)
+            batch_shape, data_shape=data_shape, num_classes=1000)
       elif workload_name == 'librispeech':
         fake_batch = {
-            'indices': np.random.normal(size=(8,)),
-            'features': np.random.normal(size=(8, 1593, 161)),
-            'transcripts': np.random.normal(size=(8, 246)),
-            'input_lengths': np.random.normal(size=(8,)),
+            'indices': np.ones((8,)),
+            'features': np.ones((8, 1593, 161)),
+            'transcripts': np.ones((8, 246)),
+            'input_lengths': np.ones((8,)),
         }
       elif workload_name == 'ogbg':
         num_classes = 128
@@ -148,7 +148,13 @@ def build_input_queue(self, *args, **kwargs):
           for k, v in fake_batch.items()
       }
-      # We set the number of examples to the batch size for all splits, so only
-      # yield one batch.
-      yield fake_batch
+      # We set the number of examples to the batch size for all splits, so we
+      # only need to yield two batches, one for each call to eval_model().
+      num_batches = 2
+      # For WMT we also iterate through the eval iters a second time to
+      # compute the BLEU score.
+      if workload_name == 'wmt':
+        num_batches *= 2
+      for _ in range(num_batches):
+        yield fake_batch
 
   return _OneEvalBatchWorkload()
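Why two batches (and four for WMT): the fake workload sets num_examples equal to the global batch size, so under the new ceil logic each eval_model() call consumes exactly one batch, and the test calls eval_model() twice. A worked check under that assumption:

import math

# Assumption from the test setup: num_examples == global_batch_size.
num_examples = global_batch_size = 8
# Each eval_model() call therefore draws exactly one batch ...
assert int(math.ceil(num_examples / global_batch_size)) == 1
# ... so two eval_model() calls need 2 batches, and WMT, which re-reads
# the eval iterator to compute BLEU, needs 2 * 2 = 4.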
@@ -258,9 +264,6 @@ def test_submission(self):
-      # # DO NOT SUBMIT
-      # if 'mnist' in submission_dir or 'imagenet' in submission_dir or 'librispeech' in submission_dir or 'ogbg' in submission_dir:
-      #   continue
       submission_path = (f'reference_submissions/{workload_name}/'
                          f'{workload_name}_{framework}/submission.py')
-      logging.info(f'\n\n========= Testing {workload_name} in {framework}.')
+      logging.info(f'========= Testing {workload_name} in {framework}.')
       eval_result = _test_submission(
           workload_name,
           framework,
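A side note on the int(math.ceil(...)) idiom this change standardizes on: it round-trips through a float, which is exact at these dataset sizes but can misround for astronomically large counts. Pure-integer ceiling division is an equivalent alternative (a sketch, not something this patch changes):

import math

num_examples, global_batch_size = 10_000, 1024
# All three forms agree; the last two never leave integer arithmetic.
assert int(math.ceil(num_examples / global_batch_size)) == 10
assert -(-num_examples // global_batch_size) == 10
assert (num_examples + global_batch_size - 1) // global_batch_size == 10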