formatting
priyakasimbeg committed Oct 17, 2024
1 parent f574bf0 commit 6cb1e05
Showing 13 changed files with 58 additions and 41 deletions.
6 changes: 3 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
@@ -31,10 +31,10 @@ def __call__(self,
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
- # Preserve default behavior for backwards compatibility

+ # Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
- use_running_average_bn = not update_batch_norm
+ use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=use_running_average_bn,
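For context, the backward-compatibility default preserved above can be seen in isolation in a minimal Flax sketch (the TinyBNBlock module and its sizes are hypothetical, not the workload's ResNet):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn


    class TinyBNBlock(nn.Module):
      """Hypothetical block mirroring the backward-compatible BN flag."""

      @nn.compact
      def __call__(self, x, update_batch_norm=True, use_running_average_bn=None):
        # Preserve default behavior for backwards compatibility: callers that
        # only pass update_batch_norm get the old semantics.
        if use_running_average_bn is None:
          use_running_average_bn = not update_batch_norm
        x = nn.Conv(features=8, kernel_size=(3, 3), use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=use_running_average_bn,
                         momentum=0.9)(x)
        return nn.relu(x)


    variables = TinyBNBlock().init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3)))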
3 changes: 2 additions & 1 deletion algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -111,7 +111,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ use_running_average_bn: Optional[bool] = None
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -252,9 +252,7 @@ def _eval_model_on_split(self,
for _ in range(num_batches):
batch = next(self._eval_iters[split])
batch_metrics = self._eval_model(params, batch, model_rng)
- total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
- }
+ total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
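The collapsed dict comprehension above is behavior-preserving; a self-contained sketch of the accumulation pattern (toy metric names and values, assumed for illustration):

    # Running totals start at zero and are summed batch by batch.
    total_metrics = {'loss': 0.0, 'accuracy': 0.0, 'num_examples': 0}
    batches = [
        {'loss': 1.25, 'accuracy': 6, 'num_examples': 8},
        {'loss': 0.75, 'accuracy': 7, 'num_examples': 8},
    ]
    for batch_metrics in batches:
      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
    # total_metrics == {'loss': 2.0, 'accuracy': 13, 'num_examples': 16}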
@@ -88,9 +88,9 @@ def __call__(self,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

- # Preserve default behavior for backwards compatibility
+ # Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
- use_running_average_bn = not update_batch_norm
+ use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=use_running_average_bn,
@@ -149,7 +149,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ use_running_average_bn: Optional[bool] = None
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -309,9 +309,7 @@ def _eval_model_on_split(self,
update_batch_norm=False)
weights = batch.get('weights')
batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
- total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
- }
+ total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
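Under DDP, the per-process totals are then summed across ranks; a minimal sketch of that aggregation (assumes torch.distributed is already initialized and each metric value is a tensor on the right device):

    import torch.distributed as dist


    def aggregate_metrics(total_metrics):
      """Sum each per-process metric tensor across all DDP ranks, in place."""
      for metric in total_metrics.values():
        dist.all_reduce(metric)  # default reduction op is SUM
      return total_metrics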
@@ -454,15 +454,19 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
- def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
+ def __call__(self,
+ inputs,
+ input_paddings,
+ update_batch_norm,
+ use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

- if use_running_average_bn:
+ if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

@@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag
keepdims=True)

var = sum_vv / count_v

if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding
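For reference, a standalone sketch of the padding-aware statistics and the running-average update these lines compute (input shapes assumed to be [batch, time, features]):

    import jax.numpy as jnp


    def masked_mean_var(inputs, input_paddings):
      """Per-feature mean/var over valid (non-padded) frames only."""
      mask = 1.0 - jnp.expand_dims(input_paddings, -1)   # [batch, time, 1]
      reduce_dims = tuple(range(inputs.ndim - 1))        # all but the feature axis
      count = jnp.sum(jnp.ones_like(inputs) * mask, axis=reduce_dims, keepdims=True)
      mean = jnp.sum(inputs * mask, axis=reduce_dims, keepdims=True) / count
      var = jnp.sum(((inputs - mean)**2) * mask, axis=reduce_dims, keepdims=True) / count
      return mean, var


    def ema(running, batch_stat, momentum):
      """Running-statistics update applied when update_batch_norm is set."""
      return momentum * running + (1.0 - momentum) * batch_stat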
@@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
- def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running
kernel_init=nn.initializers.xavier_uniform())(
inputs)

- inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
+ inputs = BatchNorm(config)(inputs,
+ input_paddings,
+ update_batch_norm,
+ use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
@@ -588,7 +600,12 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
- def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -631,12 +648,12 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm: Optional[bool] = None,
- use_running_average_bn: Optional[bool] = None):
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm: Optional[bool] = None,
+ use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
@@ -673,7 +690,11 @@ def __call__(self,

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
- outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)
+ outputs = ConformerBlock(config)(outputs,
+ output_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
@@ -81,8 +81,8 @@ def _get_mask(self,
jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
[batch_size, 1])
multiplicity_tensor = masks_per_frame * choose_range
- multiplicity_weights = (multiplicity_weights <
- multiplicity_tensor).astype(jnp.int32)
+ multiplicity_weights = (multiplicity_weights
+ < multiplicity_tensor).astype(jnp.int32)
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
else:
pre_mask = jnp.einsum('bmt->bt', pre_mask)
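The reformatted comparison is part of SpecAugment-style mask construction: slots below a per-example threshold are kept, then summed over the multiplicity axis. A toy sketch (shapes and threshold values are assumed):

    import jax.numpy as jnp

    batch_size, multiplicity, time = 2, 3, 5
    # One candidate mask per (example, slot): [batch, multiplicity, time].
    pre_mask = jnp.ones((batch_size, multiplicity, time))
    slot_ids = jnp.tile(jnp.arange(multiplicity)[None, :], [batch_size, 1])
    # Keep the first k slots per example, here k = [1, 2].
    keep = (slot_ids < jnp.array([[1], [2]])).astype(jnp.int32)
    mask = jnp.einsum('bmt,bm->bt', pre_mask, keep)
    # Row 0 sums one active slot, row 1 sums two, at every time step.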
@@ -108,7 +108,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ use_running_average_bn: Optional[bool] = None
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings):
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()

else:
mean = self.running_mean
var = self.running_var
@@ -260,8 +260,9 @@ def greedy_decode(
idxs = torch.arange(
fin_result.numel(), device=result.device).view(*fin_result.shape)
mask = torch.arange(
- fin_result.shape[1], device=result.device).view(
- 1, -1) < result.count_nonzero(dim=1).view(-1, 1)
+ fin_result.shape[1],
+ device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
+ -1, 1)
fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
padding = fin_result == 0
return fin_result, padding
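greedy_decode left-packs the non-blank tokens into fin_result; a toy CPU walk-through of the re-wrapped mask logic (values assumed, blank_id = 0 so count_nonzero counts real tokens):

    import torch

    blank_id = 0
    # Greedy outputs with blanks still interleaved (toy values).
    result = torch.tensor([[5, 0, 7, 0],
                           [0, 3, 0, 0]])
    fin_result = torch.zeros_like(result)
    idxs = torch.arange(fin_result.numel()).view(*fin_result.shape)
    # Per row, the first count_nonzero positions are the destinations.
    mask = (torch.arange(fin_result.shape[1]).view(1, -1)
            < result.count_nonzero(dim=1).view(-1, 1))
    fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
    # fin_result == [[5, 7, 0, 0], [3, 0, 0, 0]]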
@@ -329,9 +330,7 @@ def _eval_model_on_split(self,
'word_errors': word_errors,
'num_words': num_words,
}
- total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
- }
+ total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
4 changes: 1 addition & 3 deletions algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,8 +214,6 @@ def _eval_model_on_split(self,
batch,
model_state,
per_device_model_rngs)
- total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
- }
+ total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}

return self._normalize_eval_metrics(num_examples, total_metrics)
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -942,8 +942,8 @@ def forward(self,
# not the remaining zero elements.
if attn_mask is not None:
raise ValueError('Attention mask has to be None for decode == True.')
- attn_mask = (torch.arange(max_len, device=k.device) >=
- cache_index).reshape(1, max_len)
+ attn_mask = (torch.arange(max_len, device=k.device)
+ >= cache_index).reshape(1, max_len)

# Update sequence length to account for complete sequence.
seq_len = k.size(1)
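The re-wrapped mask expression is unchanged in meaning; a toy evaluation (max_len and cache_index values assumed) of what it produces:

    import torch

    max_len, cache_index = 6, torch.tensor(3)
    attn_mask = (torch.arange(max_len) >= cache_index).reshape(1, max_len)
    # attn_mask == [[False, False, False, True, True, True]]
    # i.e. positions at or past the cache index are flagged, matching the
    # comment above about ignoring the not-yet-filled cache entries.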
