librispeech_conformer now running
Still need to test out (a) output losses and (b) speed, and (c) look into the other librispeech variants.
rka97 committed Jan 9, 2025
1 parent e6037d6 commit 0698e34
Showing 2 changed files with 77 additions and 29 deletions.
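
At a high level, this commit moves the Conformer workload from jax.pmap-style data parallelism to jax.jit with explicit shardings: parameters and model state are replicated across a device mesh, batches are sharded along their leading dimension, and the pmap-only pieces (axis names, per-device rng splitting, unreplicate calls) are removed. A minimal sketch of that pattern, assuming a one-dimensional mesh whose axis is named 'batch' and a toy forward function (neither is taken verbatim from the repo):

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning all available devices (the axis name is an assumption).
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

replicated = NamedSharding(mesh, P())            # every device holds a full copy
batch_sharded = NamedSharding(mesh, P('batch'))  # leading dim split across devices

@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),
    out_shardings=batch_sharded)
def forward(params, batch):
  return batch @ params['w']

params = jax.device_put({'w': jnp.ones((16, 4))}, replicated)
batch = jax.device_put(jnp.ones((8 * jax.device_count(), 16)), batch_sharded)
out = forward(params, batch)  # output stays sharded along its first axis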
File 1 of 2:
@@ -6,6 +6,9 @@
import flax.linen as nn
import jax
from jax import lax
from jax.sharding import NamedSharding, PartitionSpec as P

from algorithmic_efficiency import sharding_utils
import jax.numpy as jnp
import numpy as np
import optax
@@ -21,7 +24,6 @@
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
models


class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):

def __init__(self,
@@ -93,8 +95,16 @@ def init_model_fn(

self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)

# Add sharding
mesh = sharding_utils.get_mesh()
params = jax.tree_map(
lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
params)
model_state = jax.tree_map(
lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
model_state)

return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
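
The hunk above swaps flax.jax_utils.replicate, which stacks a copy of every array along a new leading device axis for pmap, for jax.device_put under a replicated NamedSharding, which keeps shapes unchanged while every device holds the full value. A small sketch of the difference, assuming sharding_utils.get_replicated_sharding(mesh) returns something equivalent to NamedSharding(mesh, P()):

import jax
import numpy as np
from flax import jax_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

params = {'w': np.ones((2, 3))}

# pmap-style replication: adds a leading axis of length num_devices.
stacked = jax_utils.replicate(params)
print(stacked['w'].shape)     # (num_devices, 2, 3)

# jit-style replication: shapes are unchanged, placement comes from the sharding.
mesh = Mesh(np.array(jax.devices()), ('batch',))
replicated = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, NamedSharding(mesh, P())), params)
print(replicated['w'].shape)  # (2, 3)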
@@ -176,6 +186,7 @@ def _build_input_queue(
'targets': (targets.numpy(), target_paddings.numpy()),
}

# Use data_utils.shard_and_maybe_pad_np to handle sharding
padded_batch = data_utils.shard_and_maybe_pad_np(
numpy_batch, padding_value=1.0)
yield padded_batch
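
data_utils.shard_and_maybe_pad_np itself is not shown in this diff; the helper below is a hypothetical stand-in for what such a function typically does on the jit path: pad the final partial batch so its leading dimension divides the device count, then place it with a batch-sharded layout. The name pad_and_shard and the padding strategy are assumptions, not the library's actual implementation.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def pad_and_shard(batch, padding_value=1.0):
  """Hypothetical stand-in: pad the leading dimension to a multiple of the
  device count, then device_put with a batch-sharded NamedSharding."""
  n_devices = jax.device_count()
  mesh = Mesh(np.array(jax.devices()), ('batch',))
  sharded = NamedSharding(mesh, P('batch'))

  def _pad_one(x):
    remainder = (-x.shape[0]) % n_devices
    if remainder:
      widths = [(0, remainder)] + [(0, 0)] * (x.ndim - 1)
      x = np.pad(x, widths, constant_values=padding_value)
    return jax.device_put(x, sharded)

  return jax.tree_util.tree_map(_pad_one, batch)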
@@ -300,11 +311,16 @@ def greedy_decode(
return hyp, hyp_paddings

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def eval_step_pmapped(
jax.jit,
in_shardings=(
sharding_utils.get_replicated_sharding(), # params
sharding_utils.get_naive_sharding_spec(), # batch
sharding_utils.get_replicated_sharding(), # model_state
sharding_utils.get_replicated_sharding(), # rng
),
out_shardings=sharding_utils.get_naive_sharding_spec(),
static_argnums=(0,))
def _eval_step(
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
@@ -322,13 +338,39 @@ def eval_step_pmapped(
loss = self.loss_fn(batch['targets'], (logits, logit_paddings))

targets, target_paddings = batch['targets']
return self.metrics_bundle.gather_from_model_output(
loss_dict=loss,
decoded=decoded,
decoded_paddings=decoded_paddings,
targets=targets,
target_paddings=target_paddings,
axis_name='batch')
# Convert metrics bundle to dictionary
metrics_dict = {
'loss_per_example': loss['per_example'],
'decoded': decoded,
'decoded_paddings': decoded_paddings,
'targets': targets,
'target_paddings': target_paddings,
'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples']
}
return metrics_dict
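
One detail worth flagging in _eval_step: under pmap, each device reported n_valid_examples for its own shard and gather_from_model_output summed the shards, whereas here the loss is computed once over the whole logical batch, and the broadcast to shape (num_devices, 1) only exists to give the output a leading axis compatible with the batch-sharded out_shardings. Because eval_step below sums that array again, the count comes back scaled by the device count, which seems worth double-checking against the pmap baseline (the commit message flags output losses as untested). The arithmetic in isolation:

import jax
import jax.numpy as jnp

n_devices = len(jax.devices())
n_valid = jnp.asarray(7.0)                   # global count for this batch

tiled = jnp.zeros((n_devices, 1)) + n_valid  # shape (n_devices, 1), every entry 7.0
recovered = tiled.sum()                      # 7.0 * n_devices, not 7.0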

def eval_step(
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState):
"""Evaluates the model and returns a metrics bundle."""
metrics_dict = self._eval_step(params, batch, model_state, rng)

# Convert dictionary back to metrics bundle
metrics = self.metrics_bundle.single_from_model_output(
loss_dict={
'summed': metrics_dict['loss_per_example'].sum(),
'per_example': metrics_dict['loss_per_example'],
'n_valid_examples': metrics_dict['n_valid_examples'].sum()
},
decoded=metrics_dict['decoded'],
decoded_paddings=metrics_dict['decoded_paddings'],
targets=metrics_dict['targets'],
target_paddings=metrics_dict['target_paddings'])

return metrics
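
For context on the dictionary round trip: gather_from_model_output needs a named mapped axis (the axis_name='batch' that pmap provided), so under jit the raw arrays are returned to the host and the bundle is rebuilt with single_from_model_output; per-batch bundles are then combined with merge. A toy clu.metrics collection showing that flow (EvalMetrics below is illustrative, not the workload's actual metrics_bundle):

import flax
import jax.numpy as jnp
from clu import metrics as clu_metrics

@flax.struct.dataclass
class EvalMetrics(clu_metrics.Collection):
  loss: clu_metrics.Average.from_output('loss')

a = EvalMetrics.single_from_model_output(loss=jnp.asarray([1.0, 3.0]))
b = EvalMetrics.single_from_model_output(loss=jnp.asarray([5.0]))
merged = a.merge(b)
print(merged.compute())  # {'loss': 3.0}: (1 + 3 + 5) / 3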

def _eval_model_on_split(self,
split: str,
@@ -353,10 +395,10 @@ def _eval_model_on_split(self,
metrics_report = None
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
computed_metrics = self.eval_step_pmapped(params,
eval_batch,
model_state,
rng).unreplicate()
computed_metrics = self.eval_step(params,
eval_batch,
model_state,
rng)

if metrics_report is None:
metrics_report = computed_metrics
@@ -368,15 +410,22 @@

return computed_metrics

@functools.partial(
jax.jit,
in_shardings=(
sharding_utils.get_replicated_sharding(), # model_state
),
out_shardings=sharding_utils.get_replicated_sharding(),
static_argnums=(0,)
)
def sync_batch_stats(
self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics and
# we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
return new_model_state
"""Sync batch statistics across replicas."""
# Replace pmean with direct mean across devices
new_batch_stats = jax.tree_map(
lambda x: jnp.mean(x, axis=0),
model_state['batch_stats'])
return model_state.copy({'batch_stats': new_batch_stats})
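
The rewritten sync_batch_stats drops the pmap/lax.pmean collective and takes a plain mean over axis 0 of each batch_stats entry, as sketched below. The sketch (and the new code) assumes the statistics still carry a per-device leading axis, as they did under pmap; if the jitted training step now computes statistics over the fully sharded global batch, that axis may no longer exist and the sync may not be needed at all, so this is worth verifying.

import jax
import jax.numpy as jnp

# Toy per-device batch stats: leading axis = number of devices (here 4).
batch_stats = {'bn': {'mean': jnp.arange(8.0).reshape(4, 2)}}

synced = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), batch_stats)
# synced['bn']['mean'] has shape (2,): the four per-device estimates averaged.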


class LibriSpeechConformerAttentionTemperatureWorkload(
File 2 of 2:
@@ -159,7 +159,6 @@ def update_params(workload: spec.Workload,
del eval_results

optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
if hasattr(hyperparameters, 'label_smoothing'):
label_smoothing = hyperparameters.label_smoothing
else:
@@ -182,7 +181,7 @@
replicated, # optimizer_state
replicated, # current_param_container
sharded, # batch
sharded, # per_device_rngs
replicated, # rngs
replicated, # grad_clip
replicated # label_smoothing
)
@@ -206,7 +205,7 @@
optimizer_state,
current_param_container,
batch,
per_device_rngs,
rng,
grad_clip,
label_smoothing)
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
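
In the submission file, the per-device jax.random.split(rng, jax.local_device_count()) that fed pmap is removed: the jitted training step now takes a single replicated rng, with any further splitting done inside the step. A minimal sketch of the two styles (the train_step body is illustrative only, not the submission's update rule):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)

# pmap style: one key per device, passed through the mapped axis.
per_device_rngs = jax.random.split(rng, jax.local_device_count())

# jit style: a single key goes in; split further inside the step as needed.
@jax.jit
def train_step(rng, batch):
  dropout_rng, new_rng = jax.random.split(rng)
  keep = jax.random.bernoulli(dropout_rng, 0.9, batch.shape)
  return batch * keep, new_rng

out, next_rng = train_step(rng, jnp.ones((8, 4)))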
