Commit

deepspeech modeldiffs
chandramouli-sastry committed Feb 14, 2024
1 parent 572cebf commit f660a21
Showing 7 changed files with 167 additions and 8 deletions.
@@ -200,7 +200,7 @@ def __init__(self, config: DeepspeechConfig):
     if config.layernorm_everywhere:
       self.normalization_layer = LayerNorm(config.encoder_dim)
     else:
-      self.normalization_layer = BatchNorm(
+      self.bn_normalization_layer = BatchNorm(
           dim=config.encoder_dim,
           batch_norm_momentum=config.batch_norm_momentum,
           batch_norm_epsilon=config.batch_norm_epsilon)
@@ -216,7 +216,7 @@ def forward(self, inputs, input_paddings):
     if self.config.layernorm_everywhere:
       inputs = self.normalization_layer(inputs)
     else:  # batchnorm
-      inputs = self.normalization_layer(inputs, input_paddings)
+      inputs = self.bn_normalization_layer(inputs, input_paddings)
 
     inputs = self.lin(inputs)
 
@@ -288,11 +288,11 @@ def __init__(self, config: DeepspeechConfig):
     self.bidirectional = bidirectional
 
     if config.layernorm_everywhere:
-      self.normalization_layer = nn.LayerNorm(config.encoder_dim)
+      self.normalization_layer = LayerNorm(config.encoder_dim)
     else:
-      self.normalization_layer = BatchNorm(config.encoder_dim,
-                                           config.batch_norm_momentum,
-                                           config.batch_norm_epsilon)
+      self.bn_normalization_layer = BatchNorm(config.encoder_dim,
+                                              config.batch_norm_momentum,
+                                              config.batch_norm_epsilon)
 
     if bidirectional:
       self.lstm = nn.LSTM(
@@ -308,7 +308,7 @@ def forward(self, inputs, input_paddings):
     if self.config.layernorm_everywhere:
       inputs = self.normalization_layer(inputs)
     else:
-      inputs = self.normalization_layer(inputs, input_paddings)
+      inputs = self.bn_normalization_layer(inputs, input_paddings)
     lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
     packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
         inputs, lengths, batch_first=True, enforce_sorted=False)
@@ -357,7 +357,7 @@ def __init__(self, config: DeepspeechConfig):
         [FeedForwardModule(config) for _ in range(config.num_ffn_layers)])
 
     if config.enable_decoder_layer_norm:
-      self.ln = nn.LayerNorm(config.encoder_dim)
+      self.ln = LayerNorm(config.encoder_dim)
     else:
       self.ln = nn.Identity()
 
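Renaming the BatchNorm attribute also renames its parameters and buffers in the module's PyTorch state dict. A minimal, self-contained sketch of that effect; `Block` and `nn.BatchNorm1d` are stand-ins for illustration, not the repository's classes:

from torch import nn


class Block(nn.Module):
  # Stand-in for the feed-forward block above, keeping only the
  # normalization-layer selection logic from the diff.

  def __init__(self, layernorm_everywhere: bool, dim: int = 8):
    super().__init__()
    if layernorm_everywhere:
      self.normalization_layer = nn.LayerNorm(dim)
    else:
      self.bn_normalization_layer = nn.BatchNorm1d(dim)


print(list(Block(layernorm_everywhere=False).state_dict().keys()))
# ['bn_normalization_layer.weight', 'bn_normalization_layer.bias',
#  'bn_normalization_layer.running_mean', 'bn_normalization_layer.running_var',
#  'bn_normalization_layer.num_batches_tracked']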
Empty file.
53 changes: 53 additions & 0 deletions tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
@@ -0,0 +1,53 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \
    LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
    LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PyTorchWorkload()

  # Test outputs for identical weights and inputs.
  wave = torch.randn(2, 320000)
  pad = torch.zeros_like(wave)
  pad[0, 200000:] = 1

  jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())}
  pyt_batch = {'inputs': (wave, pad)}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=lambda out_outpad: out_outpad[0] *
      (1 - out_outpad[1][:, :, None]))
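The out_transform passed to out_diff multiplies the model output by (1 - padding), zeroing the padded frames before the two frameworks' outputs are compared. A minimal sketch of that masking on dummy tensors; the shapes are illustrative assumptions, not the workload's real output dimensions:

import torch

logits = torch.randn(2, 10, 29)    # (batch, time, vocab) -- assumed shapes
out_paddings = torch.zeros(2, 10)  # 1 marks a padded frame, as with `pad` above
out_paddings[0, 6:] = 1
masked = logits * (1 - out_paddings[:, :, None])  # same expression as out_transform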
Empty file.
53 changes: 53 additions & 0 deletions tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
@@ -0,0 +1,53 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \
    LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
    LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PyTorchWorkload()

  # Test outputs for identical weights and inputs.
  wave = torch.randn(2, 320000)
  pad = torch.zeros_like(wave)
  pad[0, 200000:] = 1

  jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())}
  pyt_batch = {'inputs': (wave, pad)}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=lambda out_outpad: out_outpad[0] *
      (1 - out_outpad[1][:, :, None]))
Empty file.
53 changes: 53 additions & 0 deletions tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
@@ -0,0 +1,53 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \
    LibriSpeechDeepSpeechTanhWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
    LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PyTorchWorkload()

  # Test outputs for identical weights and inputs.
  wave = torch.randn(2, 320000)
  pad = torch.zeros_like(wave)
  pad[0, 200000:] = 1

  jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())}
  pyt_batch = {'inputs': (wave, pad)}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=lambda out_outpad: out_outpad[0] *
      (1 - out_outpad[1][:, :, None]))
