From 0698e3418acd090f9207243acee969dfdd80056d Mon Sep 17 00:00:00 2001
From: rka97
Date: Tue, 7 Jan 2025 21:18:44 +0000
Subject: [PATCH] librispeech_conformer now running

Still need to (a) test the output losses, (b) test speed, and (c) look
into the other librispeech workloads.
---
 .../librispeech_jax/workload.py               | 101 +++++++++++++-----
 .../nesterov/jax/submission.py                |   5 +-
 2 files changed, 77 insertions(+), 29 deletions(-)

diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
index f4d1ab0f3..4bcb711f5 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -6,6 +6,9 @@
 import flax.linen as nn
 import jax
 from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec as P
+
+from algorithmic_efficiency import sharding_utils
 import jax.numpy as jnp
 import numpy as np
 import optax
@@ -21,7 +24,6 @@
 from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
     models
 
-
 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
 
   def __init__(self,
@@ -93,8 +95,16 @@ def init_model_fn(
 
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+
+    # Add sharding: replicate params and model state across the device mesh.
+    mesh = sharding_utils.get_mesh()
+    params = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        params)
+    model_state = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        model_state)
+
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -176,6 +186,7 @@ def _build_input_queue(
           'targets': (targets.numpy(), target_paddings.numpy()),
       }
 
+      # Use data_utils.shard_and_maybe_pad_np to shard the batch across devices.
       padded_batch = data_utils.shard_and_maybe_pad_np(
           numpy_batch, padding_value=1.0)
       yield padded_batch
@@ -300,11 +311,16 @@ def greedy_decode(
     return hyp, hyp_paddings
 
   @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,))
-  def eval_step_pmapped(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # params
+          sharding_utils.get_naive_sharding_spec(),  # batch
+          sharding_utils.get_replicated_sharding(),  # model_state
+          sharding_utils.get_replicated_sharding(),  # rng
+      ),
+      out_shardings=sharding_utils.get_naive_sharding_spec(),
+      static_argnums=(0,))
+  def _eval_step(
       self,
       params: spec.ParameterContainer,
       batch: Dict[str, spec.Tensor],
@@ -322,13 +338,39 @@ def eval_step_pmapped(
     loss = self.loss_fn(batch['targets'], (logits, logit_paddings))
 
     targets, target_paddings = batch['targets']
-    return self.metrics_bundle.gather_from_model_output(
-        loss_dict=loss,
-        decoded=decoded,
-        decoded_paddings=decoded_paddings,
-        targets=targets,
-        target_paddings=target_paddings,
-        axis_name='batch')
+    # Return raw outputs as a dict; eval_step converts them to a metrics bundle.
+    metrics_dict = {
+        'loss_per_example': loss['per_example'],
+        'decoded': decoded,
+        'decoded_paddings': decoded_paddings,
+        'targets': targets,
+        'target_paddings': target_paddings,
+        'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples']
+    }
+    return metrics_dict
+
+  def eval_step(
+      self,
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState):
+    """Evaluates the model and returns a metrics bundle."""
+    metrics_dict = self._eval_step(params, batch, model_state, rng)
+
+    # Convert dictionary back to metrics bundle
+    metrics = self.metrics_bundle.single_from_model_output(
+        loss_dict={
+            'summed': metrics_dict['loss_per_example'].sum(),
+            'per_example': metrics_dict['loss_per_example'],
+            'n_valid_examples': metrics_dict['n_valid_examples'].sum()
+        },
+        decoded=metrics_dict['decoded'],
+        decoded_paddings=metrics_dict['decoded_paddings'],
+        targets=metrics_dict['targets'],
+        target_paddings=metrics_dict['target_paddings'])
+
+    return metrics
 
   def _eval_model_on_split(self,
                            split: str,
@@ -353,10 +395,10 @@ def _eval_model_on_split(self,
     metrics_report = None
     for _ in range(num_batches):
       eval_batch = next(self._eval_iters[split])
-      computed_metrics = self.eval_step_pmapped(params,
-                                                eval_batch,
-                                                model_state,
-                                                rng).unreplicate()
+      computed_metrics = self.eval_step(params,
+                                        eval_batch,
+                                        model_state,
+                                        rng)
 
       if metrics_report is None:
         metrics_report = computed_metrics
@@ -368,15 +410,22 @@ def _eval_model_on_split(self,
 
     return computed_metrics
 
+  @functools.partial(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # model_state
+      ),
+      out_shardings=sharding_utils.get_replicated_sharding(),
+      static_argnums=(0,)
+  )
   def sync_batch_stats(
       self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
-    # An axis_name is passed to pmap which can then be used by pmean.
-    # In this case each device has its own version of the batch statistics and
-    # we average them.
-    avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
-    return new_model_state
+    """Sync batch statistics across replicas."""
+    # Replace pmean with direct mean across devices
+    new_batch_stats = jax.tree_map(
+        lambda x: jnp.mean(x, axis=0),
+        model_state['batch_stats'])
+    return model_state.copy({'batch_stats': new_batch_stats})
 
 
 class LibriSpeechConformerAttentionTemperatureWorkload(
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index a24e3baab..6a903fd7d 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -159,7 +159,6 @@ def update_params(workload: spec.Workload,
   del eval_results
 
   optimizer_state, opt_update_fn = optimizer_state
-  per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
     label_smoothing = hyperparameters.label_smoothing
   else:
@@ -182,7 +181,7 @@ def update_params(workload: spec.Workload,
          replicated,  # optimizer_state
          replicated,  # current_param_container
          sharded,  # batch
-          sharded,  # per_device_rngs
+          replicated,  # rngs
          replicated,  # grad_clip
          replicated  # label_smoothing
      )
@@ -206,7 +205,7 @@ def update_params(workload: spec.Workload,
       optimizer_state,
       current_param_container,
       batch,
-      per_device_rngs,
+      rng,
       grad_clip,
       label_smoothing)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
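
Note (editor's sketch, not part of the patch): the diff calls into a sharding_utils
module whose implementation is not shown here. The helper names below (get_mesh,
get_replicated_sharding, get_naive_sharding_spec) are taken from the calls in the
patch, but their bodies are only a minimal guess at what such helpers might look
like on top of jax.sharding, not the repository's actual code.

from typing import Optional

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_mesh() -> Mesh:
  # One-dimensional mesh over all available devices, with a single 'batch' axis.
  return Mesh(jax.devices(), ('batch',))


def get_replicated_sharding(mesh: Optional[Mesh] = None) -> NamedSharding:
  # Every device holds a full copy of the array (empty PartitionSpec).
  mesh = mesh if mesh is not None else get_mesh()
  return NamedSharding(mesh, P())


def get_naive_sharding_spec(mesh: Optional[Mesh] = None) -> NamedSharding:
  # Split the leading (batch) dimension of the array across the 'batch' axis.
  mesh = mesh if mesh is not None else get_mesh()
  return NamedSharding(mesh, P('batch'))

Under these assumptions, jax.device_put(x, get_replicated_sharding(mesh)) in
init_model_fn copies the full parameter tree to every device, while the
get_naive_sharding_spec() used for the batch splits its leading dimension across
devices, mirroring what pmap previously did with a per-device batch axis.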