Merge pull request #549 from chandramouli-sastry/conformer_fix
Conformer OOM fix
priyakasimbeg authored Oct 20, 2023
2 parents 45a7730 + 28a1ff0 commit 25fb3a0
Showing 4 changed files with 78 additions and 248 deletions.
@@ -71,7 +71,10 @@ def forward(self, x):

class Subsample(nn.Module):

def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
def __init__(self,
encoder_dim: int = 0,
input_dropout_rate: float = 0.0,
num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
self.input_dropout_rate = input_dropout_rate
@@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
self.conv2 = Conv2dSubsampling(
input_channels=encoder_dim, output_channels=encoder_dim)

self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
self.linear = nn.Linear(
in_features=self.encoder_dim * num_bins // 4,
out_features=self.encoder_dim,
bias=True)
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
self.dropout = nn.Dropout(p=self.input_dropout_rate)

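The explicit nn.Linear above replaces a LazyLinear, which means the flattened feature size after subsampling has to be spelled out: each of the two Conv2dSubsampling layers halves the frequency axis, so num_bins shrinks to num_bins // 4 while the channel count is encoder_dim, giving encoder_dim * num_bins // 4 input features. A minimal sketch of that shape arithmetic (the batch, frame, and dimension values are illustrative, not taken from this file):

import torch

encoder_dim, num_bins = 512, 80            # assumed values for illustration
batch, frames = 2, 16

# After two stride-2 subsampling convs: time -> frames // 4, freq -> num_bins // 4,
# channels -> encoder_dim. Flattening (channels, freq) gives the Linear input size.
subsampled = torch.randn(batch, frames // 4, encoder_dim * (num_bins // 4))
linear = torch.nn.Linear(
    in_features=encoder_dim * num_bins // 4, out_features=encoder_dim, bias=True)
print(linear(subsampled).shape)            # torch.Size([2, 4, 512])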
@@ -123,6 +129,7 @@ def __init__(self,
self.kernel = nn.Parameter(
torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
self.bias = nn.Parameter(torch.zeros(output_channels))
self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))

def get_same_padding(self, input_shape):
in_height, in_width = input_shape[2:]
@@ -162,15 +169,11 @@ def forward(self, inputs, paddings):
input_length = paddings.shape[1]
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
padded_paddings = torch.cat([
paddings[:, None, :],
torch.zeros(
size=(paddings.shape[0], 1, pad_len), device=paddings.device)
],
dim=2)
padded_paddings = F.pad(
paddings[:, None, :], (0, pad_len), mode='constant', value=0)
out_padding = F.conv1d(
input=padded_paddings,
weight=torch.ones([1, 1, 1], device=paddings.device),
weight=self.paddings_kernel,
stride=self.filter_stride[:1])
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
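Two related changes above: the all-ones 1x1 kernel used to downsample the padding mask is now a registered buffer (so it lives on the right device without being re-created every forward), and the right-padding of the mask uses F.pad instead of concatenating a freshly allocated zero tensor. A self-contained sketch of the same padding-downsampling computation (the example mask and stride are made up):

import torch
import torch.nn.functional as F

paddings = torch.tensor([[0., 0., 0., 1., 1.]])   # 1.0 marks padded frames
stride = 2
input_length = paddings.shape[1]
pad_len = (input_length + stride - 1) // stride * stride - input_length

padded = F.pad(paddings[:, None, :], (0, pad_len), mode='constant', value=0)
paddings_kernel = torch.ones([1, 1, 1])           # what the registered buffer holds
out_padding = F.conv1d(padded, weight=paddings_kernel, stride=stride).squeeze(dim=1)
print(out_padding)                                # tensor([[0., 0., 1.]])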
@@ -184,11 +187,15 @@ def __init__(self, config: ConformerConfig):
self.config = config

self.ln = LayerNorm(dim=config.encoder_dim)
self.linear1 = nn.LazyLinear(
self.linear1 = nn.Linear(
in_features=config.encoder_dim,
out_features=config.encoder_dim * config.feed_forward_expansion_factor,
bias=True)
self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate)
self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
self.linear2 = nn.Linear(
in_features=config.encoder_dim * config.feed_forward_expansion_factor,
out_features=config.encoder_dim,
bias=True)

if config.feed_forward_residual_dropout_rate is None:
feed_forward_residual_dropout_rate = 0.1
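As in Subsample, the feed-forward block's lazy layers are replaced with explicitly sized ones; both sizes follow directly from encoder_dim and feed_forward_expansion_factor. A short sketch with assumed values (the activation and dropout here are placeholders, not the block's exact sequence):

import torch
import torch.nn as nn

encoder_dim, expansion = 512, 4            # assumed values for illustration
ff = nn.Sequential(
    nn.Linear(encoder_dim, encoder_dim * expansion, bias=True),
    nn.SiLU(),
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim * expansion, encoder_dim, bias=True),
)
print(ff(torch.randn(2, 10, encoder_dim)).shape)   # torch.Size([2, 10, 512])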
@@ -253,217 +260,32 @@ def forward(self, inputs):
return inputs * scale


class MHSAwithQS(nn.MultiheadAttention):
# pylint: disable=locally-disabled, use-a-generator, line-too-long, invalid-name
class MHSAwithQS(nn.Module):

def __init__(self, config: ConformerConfig):
super().__init__(
embed_dim=config.encoder_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout_rate,
bias=True,
batch_first=True)
super().__init__()
self.embed_dim = config.encoder_dim
self.num_heads = config.num_attention_heads
self.dropout = config.attention_dropout_rate
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)

def _scaled_in_proj_weight(self):
# Scale the query projection weight.
qs_input = self.in_proj_weight[:self.embed_dim].view(
self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2)
in_proj_queryW_scaled = self.qs(qs_input).transpose(
1, 2).view(*self.in_proj_weight[:self.embed_dim].shape)
in_proj_weight = torch.cat(
[in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]])
return in_proj_weight

def _scaled_in_proj_bias(self):
# Scale the query bias.
in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view(
self.num_heads, self.embed_dim // self.num_heads)).view(-1)
in_proj_bias = torch.cat(
[in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]])
return in_proj_bias

def forward(self,
query,
key,
value,
key_padding_mask=None,
need_weights: bool = True,
attn_mask=None,
average_attn_weights: bool = True):
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ''
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"

if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device))
for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
# Scale the query bias parameter and the query projection weight.
in_proj_weight = self._scaled_in_proj_weight()
in_proj_bias = self._scaled_in_proj_bias()
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
in_proj_weight,
in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else
0 if attn_mask is not None else None)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
f"The fast path was not hit because {why_not_fast_path}")

if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
else:
# Scale the query bias parameter and the query projection weight.
in_proj_weight = self._scaled_in_proj_weight()
in_proj_bias = self._scaled_in_proj_bias()
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
in_proj_weight, in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, average_attn_weights=average_attn_weights)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def forward(self, inputs, key_padding_mask=None):
batch_size, seq_len, embed_dim = inputs.shape
q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2)
q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
out = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=~key_padding_mask[:, None, None],
dropout_p=self.dropout,
).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
out = self.out_proj(out)
return out

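The rewritten MHSAwithQS no longer subclasses nn.MultiheadAttention: it projects queries, keys, and values with a single in_proj Linear, applies the QueryScaler to the per-head queries, and hands the result to F.scaled_dot_product_attention, which can dispatch to memory-efficient kernels that avoid materializing the full (seq_len x seq_len) attention-weight matrix. Note the masking convention: the caller's key_padding_mask is True at padded positions, while SDPA's boolean attn_mask is True where attention is allowed, hence the ~ inversion and the broadcast to (batch, 1, 1, seq_len). A minimal sketch of that convention with made-up shapes:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 6, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# True where a key position is padding and must be ignored.
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True

# Invert and broadcast over heads and query positions: (batch, 1, 1, seq_len).
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=~key_padding_mask[:, None, None], dropout_p=0.0)
print(out.shape)                          # torch.Size([2, 4, 6, 16])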

class MultiHeadedSelfAttention(nn.Module):
@@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig):

def forward(self, outputs, paddings):
outputs = self.ln(outputs)
outputs, _ = self.self_attention(
query=outputs,
key=outputs,
value=outputs,
key_padding_mask=paddings==1,
need_weights=False,
outputs = self.self_attention(
outputs,
key_padding_mask=paddings == 1,
)
outputs = self.dropout(outputs)
return outputs
@@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig):
self.register_buffer('running_var', running_var)
self.scale = nn.Parameter(torch.zeros(config.encoder_dim))
self.bias = nn.Parameter(torch.zeros(config.encoder_dim))
self.register_buffer('momentum',
torch.FloatTensor([config.batch_norm_momentum]))
self.register_buffer('epsilon',
torch.FloatTensor([config.batch_norm_epsilon]))

self.register_buffer('dim', torch.FloatTensor([config.encoder_dim]))
# self.momentum = config.batch_norm_momentum
# self.epsilon = config.batch_norm_epsilon
# self.dim = config.encoder_dim
self.momentum = config.batch_norm_momentum
self.epsilon = config.batch_norm_epsilon

def forward(self, inputs, input_paddings):
#inputs: NHD
#padding: NH
"""
Alternatively:
inputs[input_paddings==0] = F.batch_norm(
input = inputs[input_paddings==0],
running_mean = self.running_mean,
running_var = self.running_var,
weight = 1+self.scale,
bias = self.bias,
training = self.training,
momentum=1-self.momentum,
eps=self.epsilon
)
inputs.masked_fill(input_paddings[...,None] != 0, 0)
return inputs
"""
mask = 1 - input_paddings[:, :, None]
if self.training:
count = mask.sum()
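The diff truncates BatchNorm.forward just after the mask and count are computed; the docstring above records an alternative formulation in terms of F.batch_norm. A rough, standalone sketch of the masked-statistics idea the forward continues with (an approximation for illustration, not the file's exact code):

import torch

def masked_batch_stats(inputs, input_paddings):
  # inputs: (N, T, D); input_paddings: (N, T) with 1.0 marking padded frames.
  mask = 1 - input_paddings[:, :, None]            # 1 for valid frames
  count = mask.sum()
  masked = inputs * mask
  mean = masked.sum(dim=(0, 1)) / count            # per-feature mean over valid frames
  var = (((masked - mean) * mask) ** 2).sum(dim=(0, 1)) / count
  return mean, var

x = torch.randn(2, 5, 8)
pads = torch.zeros(2, 5)
pads[:, -1] = 1.0
print([t.shape for t in masked_batch_stats(x, pads)])   # [torch.Size([8]), torch.Size([8])]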
@@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig):
else:
input_dropout_rate = config.input_dropout_rate
self.subsample = Subsample(
encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate)
encoder_dim=config.encoder_dim,
input_dropout_rate=input_dropout_rate,
num_bins=preprocessing_config.num_bins)
self.conformers = nn.ModuleList(
[ConformerBlock(config) for _ in range(config.num_encoder_layers)])

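The common thread in the layer changes above is replacing lazily initialized modules with explicitly sized ones, which is why num_bins now has to be threaded from the preprocessing config down into Subsample. A small standalone illustration of the difference between the two layer types (the sizes are arbitrary): with LazyLinear the weight does not exist until the first forward pass, while an explicit Linear materializes its parameters at construction, so they can be initialized, counted, or placed on devices before any batch is seen.

import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=8)
eager = nn.Linear(in_features=4, out_features=8)

# LazyLinear's weight is an UninitializedParameter until the first call.
print(type(lazy.weight))        # <class 'torch.nn.parameter.UninitializedParameter'>
print(eager.weight.shape)       # torch.Size([8, 4])

# After one forward pass, the lazy layer infers in_features from its input.
lazy(torch.randn(2, 4))
print(lazy.weight.shape)        # torch.Size([8, 4])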