From 8ce8772fbdf156227c159539e6999b6234b8e8c5 Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 28 Feb 2024 00:38:20 +0000
Subject: [PATCH 1/3] Fix WMT jax config for decoding

---
 algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index e95ab4c6f..0250206a6 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -121,7 +121,7 @@ def predict_step(self,
                    max_decode_len: int,
                    beam_size: int = 4) -> spec.Tensor:
     """Predict translation with fast decoding beam search on a batch."""
-    config = models.TransformerConfig(deterministic=True, decode=True)
+    config = replace(self._eval_model.config, decode=True)
     # Prepare transformer fast-decoder call for beam search: for beam search, we
     # need to set up our decoder model to handle a batch size equal to
     # batch_size * beam_size, where each batch item's data is expanded in-place

From 74124d545f8e0f6811c095b61753aa44d28e27d2 Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 28 Feb 2024 00:55:44 +0000
Subject: [PATCH 2/3] Fix deepspeech model_state when batchnorm is not used

---
 .../workloads/librispeech_conformer/librispeech_jax/workload.py | 2 +-
 .../librispeech_deepspeech/librispeech_jax/workload.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
index 32896aaa6..8b9c408bc 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -327,7 +327,7 @@ def _eval_model_on_split(self,
                            global_step: int = 0) -> Dict[str, float]:
     """Run a full evaluation of the model."""
     del global_step
-    if model_state is not None:
+    if model_state is not None and len(model_state) > 0:
       # Sync batch statistics across replicas before evaluating.
       model_state = self.sync_batch_stats(model_state)

diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index b578d4598..644a1e996 100644
--- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -47,7 +47,7 @@ def init_model_fn(
     variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)

-    model_state = variables['batch_stats']
+    model_state = variables['batch_stats'] if not self.layernorm_everywhere else {}
     params = variables['params']
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)

From a4dff63dc828e8778e47b96aa86238cebdbf3344 Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 28 Feb 2024 01:09:37 +0000
Subject: [PATCH 3/3] Fix lint

---
 .../librispeech_deepspeech/librispeech_jax/workload.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 644a1e996..72263895f 100644
--- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -47,7 +47,8 @@ def init_model_fn(
     variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)

-    model_state = variables['batch_stats'] if not self.layernorm_everywhere else {}
+    model_state = variables[
+        'batch_stats'] if not self.layernorm_everywhere else {}
     params = variables['params']
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
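
Illustrative note (not part of the patch series): PATCH 1/3 swaps a freshly constructed config for `dataclasses.replace` on the existing eval config, overriding only `decode` so that every other field of the eval configuration carries over to beam-search decoding. The sketch below demonstrates that pattern on a minimal stand-in; `EvalTransformerConfig` and its fields are hypothetical placeholders, not the real `models.TransformerConfig`, and it assumes `replace` in the patch refers to `dataclasses.replace`.

# Minimal sketch of the dataclasses.replace pattern from PATCH 1/3.
# EvalTransformerConfig is a hypothetical stand-in for models.TransformerConfig.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class EvalTransformerConfig:
  deterministic: bool = False
  decode: bool = False
  share_embeddings: bool = True  # hypothetical extra field


# Config as the workload's eval model would already hold it.
eval_config = EvalTransformerConfig(deterministic=True)

# Copy every field, overriding only `decode`, so non-default settings
# (e.g. share_embeddings) are preserved for decoding instead of being
# reset by constructing a brand-new config.
decode_config = replace(eval_config, decode=True)
assert decode_config.deterministic and decode_config.decode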