Skip to content

Commit

Permalink
adding reference_submission_tests.py. adding missing features to mult…
Browse files Browse the repository at this point in the history
…iple workloads. refactoring evals to use eval_model() in the root abstract class in spec.py, and have all workloads define _eval_model_on_split(). cleaning up the WMT jax workload.
  • Loading branch information
znado committed Apr 11, 2022
1 parent cc29e8e commit 5ae658c
Show file tree
Hide file tree
Showing 21 changed files with 547 additions and 478 deletions.
6 changes: 3 additions & 3 deletions RULES.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def build_input_queue(
data_rng: RandomState,
split: str,
data_dir: str,
batch_size: int) -> Iterator[Dict[str, Tensor]]:
global_batch_size: int) -> Iterator[Dict[str, Tensor]]:
```

- The `build_input_queue` function will be called to produce the iterator over batches that the submitted data selection function consumes. It is responsible for all data reading, shuffling, repeating, preprocessing, and batching.
- The `build_input_queue` function will be called to produce the iterator over batches that the submitted data selection function consumes. It is responsible for all data reading, shuffling, repeating, preprocessing, and batching. Note that for Jax this should return an iterator over tensors of shape `(num_devices, per_device_batch_size, ...)`, and for PyTorch this should return tensors of shape `(global_batch_size, ...)`.

###### Model initialization

Expand Down Expand Up @@ -388,7 +388,7 @@ We will score submissions using the following algorithm described in [Benchmarki
</p>

- <img src="https://render.githubusercontent.com/render/math?math=\rho_s(\tau) = (\frac{1}{n_p}) \cdot [\text{number of problems where}\, r(p,s)\leq \tau]">

- We need to be careful about weighting tasks so as not to favor any data modality. We might need to weight the problems somehow to handle different numbers of models on a given dataset.

**The area between a submitted performance profile <img src="https://render.githubusercontent.com/render/math?math=\rho_s(\tau)"> and the performance profile of the reference implementation will be used as a score to compare submissions, where the area is computed by integrating <img src="https://render.githubusercontent.com/render/math?math=\log\tau"> from <img src="https://render.githubusercontent.com/render/math?math=[0, \infty)"> OR <img src="https://render.githubusercontent.com/render/math?math=\tau"> from <img src="https://render.githubusercontent.com/render/math?math=[1, \infty)"> , whether or not to log scale is a decision to be made after further investigation.**
Expand Down
55 changes: 53 additions & 2 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,14 @@ def build_input_queue(self,
data_rng: RandomState,
split: str,
data_dir: str,
batch_size: int):
global_batch_size: int):
"""Build the input queue for the workload data.
This is the only function that is NOT allowed to be called by submitters.
For Jax this should return an iterator over tensors of shape
(num_devices, per_device_batch_size, ...), and for PyTorch this should
return tensors of shape (global_batch_size, ...).
"""

@property
Expand Down Expand Up @@ -214,11 +218,58 @@ def loss_fn(
"""return oned_array_of_losses_per_example"""

@abc.abstractmethod
def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: ParameterContainer,
model_state: ModelAuxiliaryState,
rng: RandomState,
data_dir: str) -> Dict[str, float]:
"""Evaluate the model on a given dataset split, return final scalars."""

def eval_model(self,
global_batch_size: int,
params: ParameterContainer,
model_state: ModelAuxiliaryState,
rng: RandomState):
rng: RandomState,
data_dir: str) -> Dict[str, float]:
"""Run a full evaluation of the model."""
# DO NOT SUBMIT handle the case where batch size > num examples
train_metrics = self._eval_model_on_split(
split='eval_train',
num_examples=self.num_eval_train_examples,
global_batch_size=global_batch_size,
params=params,
model_state=model_state,
rng=rng,
data_dir=data_dir)
validation_metrics = self._eval_model_on_split(
'validation',
num_examples=self.num_validation_examples,
global_batch_size=global_batch_size,
params=params,
model_state=model_state,
rng=rng,
data_dir=data_dir)
eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
for k, v in validation_metrics.items():
eval_metrics['validation/' + k] = v
# Evaluate on the test set if we have one.
try:
test_metrics = self._eval_model_on_split(
'test',
num_examples=self.num_test_examples,
global_batch_size=global_batch_size,
params=params,
model_state=model_state,
rng=rng,
data_dir=data_dir)
for k, v in test_metrics.items():
eval_metrics['test/' + k] = v
except NotImplementedError:
pass
return eval_metrics


class TrainingCompleteError(Exception):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ def preprocess_for_eval(image_bytes,
return image


def create_split(dataset_builder,
def create_split(split,
dataset_builder,
rng,
batch_size,
global_batch_size,
train,
image_size,
resize_size,
Expand All @@ -207,21 +208,9 @@ def create_split(dataset_builder,
num_batches=None,
aspect_ratio_range=(0.75, 4.0 / 3.0),
area_range=(0.08, 1.0)):
"""Creates a split from the ImageNet dataset using TensorFlow Datasets.
Args:
dataset_builder: TFDS dataset builder for ImageNet.
batch_size: the batch size returned by the data pipeline.
train: Whether to load the train or evaluation split.
image_size: The target size of the images.
cache: Whether to cache the dataset.
Returns:
A `tf.data.Dataset`.
"""
if train:
"""Creates a split from the ImageNet dataset using TensorFlow Datasets."""
if split == 'eval_train':
split = 'train'
else:
split = 'validation'

shuffle_rng, preprocess_rng = jax.random.split(rng, 2)

Expand Down Expand Up @@ -265,10 +254,10 @@ def decode_example(example):

if train:
ds = ds.repeat()
ds = ds.shuffle(16 * batch_size, seed=shuffle_rng[0])
ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0])

ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.batch(global_batch_size, drop_remainder=True)

if num_batches is not None:
ds = ds.take(num_batches)
Expand Down Expand Up @@ -300,9 +289,10 @@ def _prepare(x):
return jax.tree_map(_prepare, xs)


def create_input_iter(dataset_builder,
def create_input_iter(split,
dataset_builder,
rng,
batch_size,
global_batch_size,
mean_rgb,
stddev_rgb,
image_size,
Expand All @@ -314,9 +304,10 @@ def create_input_iter(dataset_builder,
repeat_final_dataset,
num_batches):
ds = create_split(
split,
dataset_builder,
rng,
batch_size,
global_batch_size,
train=train,
dtype=tf.float32,
image_size=image_size,
Expand All @@ -330,7 +321,7 @@ def create_input_iter(dataset_builder,
area_range=area_range)
it = map(shard_numpy_ds, ds)

# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%
# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
it = jax_utils.prefetch_to_device(it, 2)

return it
59 changes: 18 additions & 41 deletions algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ def _build_dataset(self,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None):
if batch_size % jax.device_count() > 0:
if batch_size % jax.local_device_count() > 0:
raise ValueError('Batch size must be divisible by the number of devices')
ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir)
ds_builder.download_and_prepare()
train = split == 'train'
ds = input_pipeline.create_input_iter(
split,
ds_builder,
data_rng,
batch_size,
Expand Down Expand Up @@ -102,15 +103,15 @@ def output_activation_fn(self,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def eval_model_fn(self, params, batch, state, rng):
def _eval_model_fn(self, params, batch, state, rng):
logits, _ = self.model_fn(
params,
batch,
state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False)
return self.compute_metrics(logits, batch['label'])
return self._compute_metrics(logits, batch['label'])

def model_fn(
self,
Expand Down Expand Up @@ -146,7 +147,7 @@ def loss_fn(self, label_batch: spec.Tensor,
logits=logits_batch, labels=one_hot_labels)
return xentropy

def compute_metrics(self, logits, labels):
def _compute_metrics(self, logits, labels):
loss = jnp.mean(self.loss_fn(labels, logits))
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
metrics = {
Expand All @@ -158,61 +159,37 @@ def compute_metrics(self, logits, labels):

def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
eval_per_core_batch_size = 256
eval_total_batch_size = eval_per_core_batch_size * jax.local_device_count()
if split == 'train':
num_examples = self.num_eval_train_examples
else:
num_examples = self.num_validation_examples
num_batches = num_examples // eval_total_batch_size
data_rng, model_rng = jax.random.split(rng, 2)
# Sync batch statistics across replicas before evaluating.
model_state = self.sync_batch_stats(model_state)
num_batches = num_examples // global_batch_size
# We already repeat the dataset indefinitely in tf.data.
if self._eval_iters[split] is None:
eval_ds = self._build_dataset(
rng,
if split not in self._eval_iters:
self._eval_iters[split] = self.build_input_queue(
data_rng,
split=split,
batch_size=eval_per_core_batch_size,
global_batch_size=global_batch_size,
data_dir=data_dir,
cache=True,
repeat_final_dataset=True,
num_batches=num_batches)
self._eval_iters[split] = iter(eval_ds)

eval_metrics = {}
for _ in range(num_batches + 1):
batch = next(self._eval_iters[split])
# We already average these metrics across devices inside compute_metrics.
synced_metrics = self.eval_model_fn(params, batch, model_state, rng)
# We already average these metrics across devices inside _compute_metrics.
synced_metrics = self._eval_model_fn(
params, batch, model_state, model_rng)
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
return eval_metrics

def eval_model(self,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
"""Run a full evaluation of the model."""
# Sync batch statistics across replicas before evaluating.
model_state = self.sync_batch_stats(model_state)
train_metrics = self._eval_model_on_split('train',
params,
model_state,
rng,
data_dir)
validation_metrics = self._eval_model_on_split('validation',
params,
model_state,
rng,
data_dir)
eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
for k, v in validation_metrics.items():
eval_metrics['validation/' + k] = v
return eval_metrics
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,37 @@ def cycle(iterable):

class ImagenetWorkload(BaseImagenetWorkload):

def __init__(self):
self._eval_iters = {}

@property
def param_shapes(self):
"""
TODO: return shape tuples from model as a tree
"""
raise NotImplementedError

def eval_model(self,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
"""Run a full evaluation of the model."""
# DO NOT SUBMIT use num_examples
data_rng, model_rng = prng.split(rng, 2)
eval_batch_size = 128
if self._eval_ds is None:
self._eval_ds = self._build_dataset(
data_rng, 'test', data_dir, batch_size=eval_batch_size)
if split not in self._eval_iters:
self._eval_iters[split] = self.build_input_queue(
data_rng, split, data_dir, global_batch_size=global_batch_size)

total_metrics = {
'accuracy': 0.,
'loss': 0.,
}
n_data = 0
for (images, labels) in self._eval_ds:
for (images, labels) in self._eval_iters[split]:
images = images.float().to(DEVICE)
labels = labels.float().to(DEVICE)
logits, _ = self.model_fn(
Expand All @@ -78,7 +84,6 @@ def _build_dataset(self,
split: str,
data_dir: str,
batch_size: int):

is_train = (split == "train")

normalize = transforms.Compose([
Expand All @@ -87,8 +92,13 @@ def _build_dataset(self,
mean=[i / 255 for i in self.train_mean],
std=[i / 255 for i in self.train_stddev])
])
eval_transform_config = transforms.Compose([
transforms.Resize(self.resize_size),
transforms.CenterCrop(self.center_crop_size),
normalize
])
transform_config = {
"train":
'train':
transforms.Compose([
transforms.RandomResizedCrop(
self.center_crop_size,
Expand All @@ -97,15 +107,11 @@ def _build_dataset(self,
transforms.RandomHorizontalFlip(),
normalize
]),
"test":
transforms.Compose([
transforms.Resize(self.resize_size),
transforms.CenterCrop(self.center_crop_size),
normalize
])
'eval_train': eval_transform_config,
'validation': eval_transform_config,
}

folder = {'train': 'train', 'test': 'val'}
folder = {'train': 'train', 'validation': 'val', 'eval_train': 'train'}

dataset = ImageFolder(
os.path.join(data_dir, folder[split]),
Expand Down
Loading

0 comments on commit 5ae658c

Please sign in to comment.