Commit 11084d3

starting to add jax evals on train split. attempt at deterministic jax imagenet input pipeline.
znado committed Mar 2, 2022
1 parent d17e0ef commit 11084d3
Showing 4 changed files with 118 additions and 57 deletions.
70 changes: 50 additions & 20 deletions algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py
@@ -14,18 +14,20 @@
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def distorted_bounding_box_crop(image_bytes,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100):
def _distorted_bounding_box_crop(image_bytes,
rng,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100):
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
Args:
image_bytes: `Tensor` of binary image data.
rng: a per-example, per-step unique RNG seed.
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
where each coordinate is [0, 1) and the coordinates are arranged
as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
@@ -44,8 +46,9 @@ def distorted_bounding_box_crop(image_bytes,
cropped image `Tensor`
"""
shape = tf.io.extract_jpeg_shape(image_bytes)
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
sample_distorted_bounding_box = tf.image.stateless_sample_distorted_bounding_box(
shape,
seed=rng,
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
@@ -59,7 +62,6 @@ def distorted_bounding_box_crop(image_bytes,
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)

return image
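Replacing tf.image.sample_distorted_bounding_box with its stateless counterpart is what makes the crop reproducible: the op is a pure function of its (2,)-shaped seed. A minimal self-contained sketch of that behavior (synthetic image and values are illustrative, not from this commit):

import tensorflow as tf

# Synthetic 100x80 RGB image, encoded to JPEG bytes.
image_bytes = tf.io.encode_jpeg(tf.zeros([100, 80, 3], dtype=tf.uint8))
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
seed = tf.constant([12, 34], dtype=tf.int32)

def crop_window(seed):
  shape = tf.io.extract_jpeg_shape(image_bytes)
  begin, size, _ = tf.image.stateless_sample_distorted_bounding_box(
      shape,
      bounding_boxes=bbox,
      seed=seed,
      min_object_covered=0.1,
      aspect_ratio_range=(0.75, 1.33),
      area_range=(0.05, 1.0))
  offset_y, offset_x, _ = tf.unstack(begin)
  target_height, target_width, _ = tf.unstack(size)
  return tf.stack([offset_y, offset_x, target_height, target_width])

# The same seed yields the same crop window on every call.
assert all(crop_window(seed).numpy() == crop_window(seed).numpy())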


@@ -76,14 +78,16 @@ def _at_least_x_are_equal(a, b, x):


def _decode_and_random_crop(image_bytes,
rng,
image_size,
aspect_ratio_range,
area_range,
resize_size=RESIZE_SIZE):
"""Make a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = distorted_bounding_box_crop(
image = _distorted_bounding_box_crop(
image_bytes,
rng,
bbox,
min_object_covered=0.1,
aspect_ratio_range=aspect_ratio_range,
@@ -132,6 +136,7 @@ def normalize_image(image, mean_rgb, stddev_rgb):


def preprocess_for_train(image_bytes,
rng,
mean_rgb,
stddev_rgb,
aspect_ratio_range,
@@ -143,19 +148,23 @@ def preprocess_for_train(image_bytes,
Args:
image_bytes: `Tensor` representing an image binary of arbitrary size.
rng: a per-example, per-step unique RNG seed.
dtype: data type of the image.
image_size: image size.
Returns:
A preprocessed image `Tensor`.
"""
crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)

image = _decode_and_random_crop(image_bytes,
crop_rng,
image_size,
aspect_ratio_range,
area_range,
resize_size)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.random_flip_left_right(image)
image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)
image = normalize_image(image, mean_rgb, stddev_rgb)
image = tf.image.convert_image_dtype(image, dtype=dtype)
return image
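preprocess_for_train now derives independent seeds for each random op from the single per-example seed, so the crop and the flip cannot perturb each other. A short sketch of that splitting pattern (values are illustrative):

import tensorflow as tf

rng = tf.constant([7, 42], dtype=tf.int64)  # one per-example seed
# Split into independent seeds, one per stateless random op.
crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)

image = tf.zeros([224, 224, 3], dtype=tf.float32)
# Deterministic: the same flip_rng always makes the same flip decision.
flipped = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)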
@@ -185,14 +194,16 @@ def preprocess_for_eval(image_bytes,


def create_split(dataset_builder,
rng,
batch_size,
train,
dtype=tf.float32,
image_size=IMAGE_SIZE,
resize_size=RESIZE_SIZE,
mean_rgb=MEAN_RGB,
stddev_rgb=STDDEV_RGB,
image_size,
resize_size,
mean_rgb,
stddev_rgb,
cache=False,
repeat_final_dataset=False,
num_batches=None,
aspect_ratio_range=(0.75, 4.0 / 3.0),
area_range=(0.08, 1.0)):
"""Creates a split from the ImageNet dataset using TensorFlow Datasets.
@@ -201,7 +212,6 @@ def create_split(dataset_builder,
dataset_builder: TFDS dataset builder for ImageNet.
batch_size: the batch size returned by the data pipeline.
train: Whether to load the train or evaluation split.
dtype: data type of the image.
image_size: The target size of the images.
cache: Whether to cache the dataset.
Returns:
@@ -212,9 +222,19 @@
else:
split = 'validation'

shuffle_rng, preprocess_rng = jax.random.split(rng, 2)

def decode_example(example):
dtype = tf.float32
if train:
# We call ds.enumerate() to get a globally unique per-example, per-step
# index that we can fold into the RNG seed.
(example_index, example) = example
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
tf.cast(preprocess_rng, tf.int64), example_index)
image = preprocess_for_train(example['image'],
per_step_preprocess_rng,
mean_rgb,
stddev_rgb,
aspect_ratio_range,
@@ -232,7 +252,8 @@ def decode_example(example):
return {'image': image, 'label': example['label']}

ds = dataset_builder.as_dataset(
split=split, decoders={
split=split,
decoders={
'image': tfds.decode.SkipDecoding(),
})
options = tf.data.Options()
@@ -244,12 +265,15 @@ def decode_example(example):

if train:
ds = ds.repeat()
ds = ds.shuffle(16 * batch_size, seed=0)
ds = ds.shuffle(16 * batch_size, seed=shuffle_rng[0])
# Enumerate so decode_example receives the (example_index, example)
# pairs whose index is folded into the per-example RNG seed.
ds = ds.enumerate()

ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.batch(batch_size, drop_remainder=True)

if not train:
if num_batches is not None:
ds = ds.take(num_batches)

if not train or repeat_final_dataset:
ds = ds.repeat()

ds = ds.prefetch(10)
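The enumerate-and-fold-in pattern above gives every example a seed that is deterministic yet distinct per example and per step. A self-contained sketch of the same idea (toy dataset; names are illustrative):

import tensorflow as tf

base_seed = tf.constant([3, 91], dtype=tf.int64)

def seed_for(example_index):
  # Folding the unique index into the base seed yields a distinct,
  # reproducible per-example seed.
  return tf.random.experimental.stateless_fold_in(base_seed, example_index)

ds = tf.data.Dataset.range(4).enumerate()
ds = ds.map(lambda i, _: tf.random.stateless_uniform([], seed=seed_for(i)))
print(list(ds.as_numpy_iterator()))  # identical values on every run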
@@ -277,6 +301,7 @@ def _prepare(x):


def create_input_iter(dataset_builder,
rng,
batch_size,
mean_rgb,
stddev_rgb,
@@ -285,16 +310,21 @@
aspect_ratio_range,
area_range,
train,
cache):
cache,
repeat_final_dataset,
num_batches):
ds = create_split(
dataset_builder,
rng,
batch_size,
train=train,
image_size=image_size,
resize_size=resize_size,
mean_rgb=mean_rgb,
stddev_rgb=stddev_rgb,
cache=cache,
repeat_final_dataset=repeat_final_dataset,
num_batches=num_batches,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range)
it = map(shard_numpy_ds, ds)
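create_input_iter maps shard_numpy_ds over the dataset; that helper is not shown in this diff, but a typical implementation reshapes each host batch so pmap sees one leading slice per local device. A hedged sketch under that assumption (shard_batch is an illustrative name):

import jax
import numpy as np

def shard_batch(batch):
  # [global_batch, ...] -> [local_devices, per_device_batch, ...]
  n = jax.local_device_count()
  return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), batch)

batch = {'image': np.zeros((8, 224, 224, 3)), 'label': np.zeros((8,))}
sharded = shard_batch(batch)  # leading dim split across devices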
96 changes: 63 additions & 33 deletions algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py
@@ -1,12 +1,6 @@
"""ImageNet workload implemented in Jax.
python3 submission_runner.py \
--workload=imagenet_jax \
--submission_path=workloads/imagenet/imagenet_jax/submission.py \
--num_tuning_trials=1
"""
"""ImageNet workload implemented in Jax."""
import functools
from typing import Tuple
from typing import Optional, Tuple

import optax
import tensorflow as tf
@@ -37,27 +31,35 @@ def __init__(self):
super().__init__()
self._param_shapes = None
self.epoch_metrics = []
self._eval_iters = {}

def _build_dataset(self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
batch_size):
batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None):
if batch_size % jax.device_count() > 0:
raise ValueError('Batch size must be divisible by the number of devices')
ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir)
ds_builder.download_and_prepare()
train = split == 'train'
ds = input_pipeline.create_input_iter(
ds_builder,
data_rng,
batch_size,
self.train_mean,
self.train_stddev,
self.center_crop_size,
self.resize_size,
self.aspect_ratio_range,
self.scale_ratio_range,
train=split == 'train',
cache=False)
train=train,
cache=not train if cache is None else cache,
repeat_final_dataset=repeat_final_dataset,
num_batches=num_batches)
return ds
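The divisibility check at the top of _build_dataset exists because each global batch is later split evenly across accelerators; the per-device batch size is just integer division. A minimal sketch (batch size is illustrative):

import jax

global_batch_size = 1024
if global_batch_size % jax.device_count() > 0:
  raise ValueError('Batch size must be divisible by the number of devices')
per_device_batch_size = global_batch_size // jax.device_count()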

def sync_batch_stats(self, model_state):
@@ -95,8 +97,7 @@ def init_model_fn(self, rng: spec.RandomState) -> _InitState:
params = jax_utils.replicate(params)
return params, model_state

# Keep this separate from the loss function in order to support optimizers

# Keep this separate from the loss function in order to support optimizers
# that use the logits.
def output_activation_fn(self,
logits_batch: spec.Tensor,
@@ -164,30 +165,59 @@ def compute_metrics(self, logits, labels):
metrics = lax.pmean(metrics, axis_name='batch')
return metrics

def _eval_model_on_split(self,
split: str,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
eval_per_core_batch_size = 256
eval_total_batch_size = eval_per_core_batch_size * jax.device_count()
if split == 'train':
num_examples = self.num_eval_train_examples
else:
num_examples = self.num_validation_examples
num_batches = num_examples // eval_total_batch_size
# We already repeat the dataset indefinitely in tf.data.
if split not in self._eval_iters:
eval_ds = self._build_dataset(
rng,
split=split,
batch_size=eval_per_core_batch_size,
data_dir=data_dir,
cache=True,
repeat_final_dataset=True,
num_batches=num_batches)
self._eval_iters[split] = iter(eval_ds)

eval_metrics = {}
for _ in range(num_batches):
batch = next(self._eval_iters[split])
# We already average these metrics across devices inside compute_metrics.
synced_metrics = self.eval_model_fn(params, batch, model_state, rng)
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
return eval_metrics
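_eval_model_on_split accumulates per-batch metric values and normalizes once at the end; assuming each batch's metrics are sums over its examples, the pattern reduces to plain dictionary accumulation (toy numbers below):

# Two batches of four examples each; values are illustrative.
batch_metrics = [{'loss': 4.0, 'accuracy': 2.0},
                 {'loss': 6.0, 'accuracy': 3.0}]
totals = {}
for metrics in batch_metrics:
  for name, value in metrics.items():
    totals[name] = totals.get(name, 0.0) + value
num_examples = 8
mean_metrics = {k: v / num_examples for k, v in totals.items()}
print(mean_metrics)  # {'loss': 1.25, 'accuracy': 0.625}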

def eval_model(self,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str):
"""Run a full evaluation of the model."""
# sync batch statistics across replicas
# Sync batch statistics across replicas before evaluating.
model_state = self.sync_batch_stats(model_state)

eval_metrics = []
eval_batch_size = 200
num_batches = self.num_eval_examples // eval_batch_size
if self._eval_ds is None:
self._eval_ds = self._build_dataset(
rng, split='test', batch_size=eval_batch_size, data_dir=data_dir)
eval_iter = iter(self._eval_ds)
total_accuracy = 0.
for _ in range(num_batches):
batch = next(eval_iter)
synced_metrics = self.eval_model_fn(params, batch, model_state, rng)
eval_metrics.append(synced_metrics)
total_accuracy += jnp.mean(synced_metrics['accuracy'])

eval_metrics = jax.device_get(jax.tree_map(lambda x: x[0], eval_metrics))
eval_metrics = jax.tree_multimap(lambda *x: np.stack(x), *eval_metrics)
summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
return summary
train_metrics = self._eval_model_on_split(
'train', params, model_state, rng, data_dir)
validation_metrics = self._eval_model_on_split(
'validation', params, model_state, rng, data_dir)
eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
for k, v in validation_metrics.items():
eval_metrics['validation/' + k] = v
return eval_metrics
6 changes: 5 additions & 1 deletion algorithmic_efficiency/workloads/imagenet/workload.py
@@ -22,7 +22,11 @@ def num_train_examples(self):
return 1281167

@property
def num_eval_examples(self):
def num_eval_train_examples(self):
return 50000

@property
def num_validation_examples(self):
return 50000

@property
3 changes: 0 additions & 3 deletions submission_runner.py
@@ -109,7 +109,6 @@ def _convert_filepath_to_module(path: str):


def _import_workload(workload_path: str,
workload_registry_name: str,
workload_class_name: str) -> spec.Workload:
"""Import and add the workload to the registry.
@@ -121,7 +120,6 @@ def _import_workload(workload_path: str,
Args:
workload_path: the path to the `workload.py` file to load.
workload_registry_name: the name to register the workload class under.
workload_class_name: the name of the Workload class that implements the
`Workload` abstract class in `spec.py`.
"""
@@ -302,7 +300,6 @@ def main(_):
workload_metadata = WORKLOADS[FLAGS.workload]
workload = _import_workload(
workload_path=workload_metadata['workload_path'],
workload_registry_name=FLAGS.workload,
workload_class_name=workload_metadata['workload_class_name'])

score = score_submission_on_workload(workload,
