
Commit

Lint fix
pomonam committed Feb 6, 2024
1 parent 0dced90 commit 572cebf
Showing 4 changed files with 28 additions and 20 deletions.
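All 28 additions and 20 deletions are formatting-only; no behavior changes. The hunks below apply a handful of standard Python style rules: two blank lines between top-level definitions, no doubled blank lines inside a function body, continuation arguments split and aligned under the opening parenthesis to respect the line-length limit, two spaces before an inline comment, and a trailing newline at end of file. A few removed/added pairs (the @property lines, for example) render identically because they differ only in whitespace.

As a minimal sketch of these rules (the names here are hypothetical, not from the repository):

# Hypothetical example illustrating the style rules applied in this commit.
# Two blank lines separate top-level definitions.


def build_config(dim, momentum, epsilon):
  return {'dim': dim, 'momentum': momentum, 'epsilon': epsilon}


def main():
  # Continuation arguments are aligned under the opening parenthesis.
  config = build_config(512,
                        0.999,
                        0.001)
  print(config)  # Inline comments get two leading spaces.


if __name__ == '__main__':
  main()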
[File 1 of 4; filename not captured in this view]
@@ -61,6 +61,7 @@ class DeepspeechConfig:
   use_tanh: bool = False
   layernorm_everywhere: bool = False
 
+
 class Subsample(nn.Module):
   """Module to perform strided convolution in order to subsample inputs.
@@ -161,7 +162,6 @@ def __call__(self, inputs, paddings, train):
     else:
       outputs = nn.relu(outputs)
 
-
     # Computing correct paddings post input convolution.
     input_length = paddings.shape[1]
     stride = self.filter_stride[0]
@@ -196,9 +196,11 @@ def __call__(self, inputs, input_paddings=None, train=False):
       inputs = LayerNorm(config.encoder_dim)(inputs)
     else:
       inputs = BatchNorm(config.encoder_dim,
-          config.dtype,
-          config.batch_norm_momentum,
-          config.batch_norm_epsilon)(inputs, input_paddings, train)
+                         config.dtype,
+                         config.batch_norm_momentum,
+                         config.batch_norm_epsilon)(inputs,
+                                                    input_paddings,
+                                                    train)
     inputs = nn.Dense(
         config.encoder_dim,
         use_bias=True,
@@ -436,9 +438,11 @@ def __call__(self, inputs, input_paddings, train):
       inputs = LayerNorm(config.encoder_dim)(inputs)
     else:
       inputs = BatchNorm(config.encoder_dim,
-          config.dtype,
-          config.batch_norm_momentum,
-          config.batch_norm_epsilon)(inputs, input_paddings, train)
+                         config.dtype,
+                         config.batch_norm_momentum,
+                         config.batch_norm_epsilon)(inputs,
+                                                    input_paddings,
+                                                    train)
     output = CudnnLSTM(
         features=config.encoder_dim // 2,
         bidirectional=config.bidirectional,
[File 2 of 4; filename not captured in this view]
@@ -36,7 +36,7 @@ def init_model_fn(
         layernorm_everywhere=self.layernorm_everywhere,
         freq_mask_count=self.freq_mask_count,
         time_mask_count=self.time_mask_count,
-        )
+    )
     self._model = models.Deepspeech(model_config)
     input_shape = [(320000,), (320000,)]
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
@@ -83,7 +83,7 @@ def use_tanh(self) -> bool:
   def enable_residual_connections(self) -> bool:
     return True
 
-  @property
+  @property
   def enable_decoder_layer_norm(self) -> bool:
     return True
 
@@ -114,7 +114,8 @@ def enable_residual_connections(self) -> bool:
     return False
 
 
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload):
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
+                                                 ):
 
   @property
   def eval_batch_size(self) -> int:
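The class-header wrap in the hunk above is a line-length fix: the one-line form class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): is 81 characters, one over the usual 80-column limit, so the closing parenthesis moves to its own line. The same wrap appears again in the fourth file.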
[File 3 of 4; filename not captured in this view]
@@ -81,7 +81,9 @@ def __init__(self, config: DeepspeechConfig):
     self.conv1 = Conv2dSubsampling(
         input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
     self.conv2 = Conv2dSubsampling(
-        input_channels=encoder_dim, output_channels=encoder_dim, use_tanh=config.use_tanh)
+        input_channels=encoder_dim,
+        output_channels=encoder_dim,
+        use_tanh=config.use_tanh)
 
     self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
 
@@ -213,16 +215,16 @@ def forward(self, inputs, input_paddings):
     padding_mask = (1 - input_paddings)[:, :, None]
     if self.config.layernorm_everywhere:
       inputs = self.normalization_layer(inputs)
-    else: # batchnorm
+    else:  # batchnorm
       inputs = self.normalization_layer(inputs, input_paddings)
 
     inputs = self.lin(inputs)
 
     if self.config.use_tanh:
       inputs = F.tanh(inputs)
     else:
       inputs = F.relu(inputs)
 
     inputs = inputs * padding_mask
     inputs = self.dropout(inputs)
 
@@ -289,8 +291,8 @@ def __init__(self, config: DeepspeechConfig):
       self.normalization_layer = nn.LayerNorm(config.encoder_dim)
     else:
       self.normalization_layer = BatchNorm(config.encoder_dim,
-          config.batch_norm_momentum,
-          config.batch_norm_epsilon)
+                                           config.batch_norm_momentum,
+                                           config.batch_norm_epsilon)
 
     if bidirectional:
       self.lstm = nn.LSTM(
[File 4 of 4; filename not captured in this view]
@@ -91,7 +91,7 @@ def use_tanh(self) -> bool:
   def enable_residual_connections(self) -> bool:
     return True
 
-  @property
+  @property
   def enable_decoder_layer_norm(self) -> bool:
     return True
 
@@ -122,7 +122,8 @@ def enable_residual_connections(self) -> bool:
     return False
 
 
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload):
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
+                                                 ):
 
   @property
   def eval_batch_size(self) -> int:
@@ -142,4 +143,4 @@ def freq_mask_count(self) -> int:
 
   @property
   def time_mask_count(self) -> int:
-    return 15
+    return 15
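The final pair of return 15 lines also renders identically; the most plausible change is a newline added at the end of the file, which the rendered diff cannot display.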
