diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 6037f285f..821902003 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -400,12 +400,12 @@ The currently eight fixed workloads are:
 
 | | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target** | Maximum<br>**Runtime**<br>(in secs) |
 |------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
-| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123649 | 0.126060 | 21,600 |
-| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.7344 | 0.741652 | 10,800 |
-| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 111,600<br>111,600 |
-| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.078477<br>0.1162 | 0.046973<br>0.068093 |<br>72,000 |
-| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,000 |
-| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 80,000 |
+| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 |
+| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 |
+| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 63,008<br>77,520 |
+| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.085884<br>0.119936 | 0.052981<br>0.074143 | 61,068<br>55,506 |
+| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
+| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |
 
 #### Randomized workloads
 
diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 0bc1cd8cc..93f603bca 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
   def initialized(self, key: spec.RandomState,
                   model: nn.Module) -> spec.ModelInitState:
     input_shape = (1, 224, 224, 3)
-    variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
+    params_rng, dropout_rng = jax.random.split(key)
+    variables = jax.jit(
+        model.init)({'params': params_rng, 'dropout': dropout_rng},
+                    jnp.ones(input_shape))
     model_state, params = variables.pop('params')
     return params, model_state
 
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 0250206a6..b10d4056d 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -234,10 +234,11 @@ def init_model_fn(
     self._train_model = models.Transformer(model_config)
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
-    initial_variables = jax.jit(self._eval_model.init)(
-        rng,
-        jnp.ones(input_shape, jnp.float32),
-        jnp.ones(target_shape, jnp.float32))
+    params_rng, dropout_rng = jax.random.split(rng)
+    initial_variables = jax.jit(
+        self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+                               jnp.ones(input_shape, jnp.float32),
+                               jnp.ones(target_shape, jnp.float32))
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
 
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 1d3e5d2c7..ef8667660 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -291,9 +291,22 @@ def download_criteo1tb(data_dir,
       stream=True)
   all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
-  with open(all_days_zip_filepath, 'wb') as f:
-    for chunk in download_request.iter_content(chunk_size=1024):
-      f.write(chunk)
+  download = True
+  if os.path.exists(all_days_zip_filepath):
+    while True:
+      overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
+          all_days_zip_filepath)).lower()
+      if overwrite in ['y', 'n']:
+        break
+      logging.info('Invalid response. Try again.')
+    if overwrite == 'n':
+      logging.info(f'Skipping download to {all_days_zip_filepath}')
+      download = False
+
+  if download:
+    with open(all_days_zip_filepath, 'wb') as f:
+      for chunk in download_request.iter_content(chunk_size=1024):
+        f.write(chunk)
 
   unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
   logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
 
@@ -679,6 +692,7 @@ def main(_):
   if any(s in tmp_dir for s in bad_chars):
     raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
   data_dir = os.path.abspath(os.path.expanduser(data_dir))
+  tmp_dir = os.path.abspath(os.path.expanduser(tmp_dir))
   logging.info('Downloading data to %s...', data_dir)
 
   if FLAGS.all or FLAGS.criteo1tb:
diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index c5352df90..5c5a6aa49 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -157,8 +157,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
                  "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \
                  "wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \
                  "librispeech_deepspeech" "librispeech_conformer" "mnist" \
-                 "conformer_layernorm" "conformer_attention_temperature" \
-                 "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
+                 "librispeech_conformer_layernorm" "librispeech_conformer_attention_temperature" \
+                 "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
                  "librispeech_deepspeech_tanh" \
                  "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")
diff --git a/reference_algorithms/target_setting_algorithms/get_batch_size.py b/reference_algorithms/target_setting_algorithms/get_batch_size.py
index 4c7c96241..3bdc90f36 100644
--- a/reference_algorithms/target_setting_algorithms/get_batch_size.py
+++ b/reference_algorithms/target_setting_algorithms/get_batch_size.py
@@ -15,6 +15,8 @@ def get_batch_size(workload_name):
     return 512
   elif workload_name == 'imagenet_vit':
     return 1024
+  elif workload_name == 'imagenet_vit_glu':
+    return 512
   elif workload_name == 'librispeech_conformer':
     return 256
   elif workload_name == 'librispeech_deepspeech':
diff --git a/submission_runner.py b/submission_runner.py
index fad60e48b..6381c97f1 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -387,7 +387,9 @@ def train_once(
       train_state['test_goal_reached'] = (
           workload.has_reached_test_target(latest_eval_result) or
           train_state['test_goal_reached'])
-
+      goals_reached = (
+          train_state['validation_goal_reached'] and
+          train_state['test_goal_reached'])
       # Save last eval time.
       eval_end_time = get_time()
       train_state['last_eval_time'] = eval_end_time