diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 6037f285f..821902003 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -400,12 +400,12 @@ The currently eight fixed workloads are:
| | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target** | Maximum<br>**Runtime**<br>(in secs) |
|------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
-| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123649 | 0.126060 | 21,600 |
-| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.7344 | 0.741652 | 10,800 |
-| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 111,600<br>111,600 |
-| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.078477<br>0.1162 | 0.046973<br>0.068093 | <br>72,000 |
-| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,000 |
-| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 80,000 |
+| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 |
+| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 |
+| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 63,008<br>77,520 |
+| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.085884<br>0.119936 | 0.052981<br>0.074143 | 61,068<br>55,506 |
+| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
+| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |
#### Randomized workloads
diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 0bc1cd8cc..93f603bca 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def initialized(self, key: spec.RandomState,
model: nn.Module) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
- variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
+ params_rng, dropout_rng = jax.random.split(key)
+ variables = jax.jit(
+ model.init)({'params': params_rng, 'dropout': dropout_rng},
+ jnp.ones(input_shape))
model_state, params = variables.pop('params')
return params, model_state
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 0250206a6..b10d4056d 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -234,10 +234,11 @@ def init_model_fn(
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
- initial_variables = jax.jit(self._eval_model.init)(
- rng,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
+ params_rng, dropout_rng = jax.random.split(rng)
+ initial_variables = jax.jit(
+ self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32))
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 1d3e5d2c7..ef8667660 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -291,9 +291,22 @@ def download_criteo1tb(data_dir,
stream=True)
all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
- with open(all_days_zip_filepath, 'wb') as f:
- for chunk in download_request.iter_content(chunk_size=1024):
- f.write(chunk)
+ download = True
+ if os.path.exists(all_days_zip_filepath):
+ while True:
+ overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
+ all_days_zip_filepath)).lower()
+ if overwrite in ['y', 'n']:
+ break
+ logging.info('Invalid response. Try again.')
+ if overwrite == 'n':
+ logging.info(f'Skipping download to {all_days_zip_filepath}')
+ download = False
+
+ if download:
+ with open(all_days_zip_filepath, 'wb') as f:
+ for chunk in download_request.iter_content(chunk_size=1024):
+ f.write(chunk)
unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
@@ -679,6 +692,7 @@ def main(_):
if any(s in tmp_dir for s in bad_chars):
raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
data_dir = os.path.abspath(os.path.expanduser(data_dir))
+ tmp_dir = os.path.abspath(os.path.expanduser(tmp_dir))
logging.info('Downloading data to %s...', data_dir)
if FLAGS.all or FLAGS.criteo1tb:
diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index c5352df90..5c5a6aa49 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -157,8 +157,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \
"wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \
"librispeech_deepspeech" "librispeech_conformer" "mnist" \
- "conformer_layernorm" "conformer_attention_temperature" \
- "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
+ "librispeech_conformer_layernorm" "librispeech_conformer_attention_temperature" \
+ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
"librispeech_deepspeech_tanh" \
"librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug"
"fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")
diff --git a/reference_algorithms/target_setting_algorithms/get_batch_size.py b/reference_algorithms/target_setting_algorithms/get_batch_size.py
index 4c7c96241..3bdc90f36 100644
--- a/reference_algorithms/target_setting_algorithms/get_batch_size.py
+++ b/reference_algorithms/target_setting_algorithms/get_batch_size.py
@@ -15,6 +15,8 @@ def get_batch_size(workload_name):
return 512
elif workload_name == 'imagenet_vit':
return 1024
+ elif workload_name == 'imagenet_vit_glu':
+ return 512
elif workload_name == 'librispeech_conformer':
return 256
elif workload_name == 'librispeech_deepspeech':
diff --git a/submission_runner.py b/submission_runner.py
index fad60e48b..6381c97f1 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -387,7 +387,9 @@ def train_once(
train_state['test_goal_reached'] = (
workload.has_reached_test_target(latest_eval_result) or
train_state['test_goal_reached'])
-
+ goals_reached = (
+ train_state['validation_goal_reached'] and
+ train_state['test_goal_reached'])
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time