Conformer OOM fix #549
Conversation
MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅
torch.backends.cudnn.benchmark = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
Curious that the math SDP backend was more memory-efficient than the memory-efficient attention backend.
It is likely more memory-efficient, but it seems to require a different memory allocator configuration to sidestep the need for torch.cuda.empty_cache() after every update step.
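For context, a minimal sketch of the two options being weighed here. The allocator option shown (expandable_segments) is an assumption on my part, not something named in this thread; the backend toggles are the ones from the diff above.

```python
import os

# Assumption: the "different memory configuration setting" refers to the CUDA
# caching allocator's expandable_segments option; it must be set before the
# first CUDA allocation to take effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch

# Backend toggles from this PR: disable flash and mem-efficient SDP so
# scaled_dot_product_attention falls back to the math kernel.
torch.backends.cudnn.benchmark = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# With the allocator and backends configured this way, calling
# torch.cuda.empty_cache() after every update step should not be needed.
```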
Yeah, I feel like this is some corner-case SDPA bug, but it should still be fine to merge.
@@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
     self.conv2 = Conv2dSubsampling(
         input_channels=encoder_dim, output_channels=encoder_dim)

-    self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
+    self.linear = nn.Linear(
+        in_features=self.encoder_dim * num_bins // 4,
what's the reasoning for this?
Each of the two subsampling layers reduces the number of mel-spectrogram features by half, so the flattened input to the linear layer has encoder_dim * (num_bins // 4) features.
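A quick sketch of the shape arithmetic behind the new in_features value. The example sizes and the out_features/tensor layout are assumptions based on this explanation and the visible diff, not copied from the PR:

```python
import torch
import torch.nn as nn

encoder_dim, num_bins = 512, 80  # example values; num_bins = mel-spectrogram features

# Each Conv2dSubsampling halves the feature axis, so after two layers it is
# num_bins // 4, and the flattened per-frame input to the linear layer has
# encoder_dim * (num_bins // 4) features.
subsampled_bins = num_bins // 2 // 2
assert encoder_dim * num_bins // 4 == encoder_dim * subsampled_bins

# The explicit nn.Linear from the diff, replacing nn.LazyLinear.
linear = nn.Linear(in_features=encoder_dim * num_bins // 4,
                   out_features=encoder_dim, bias=True)

# Shape check on a dummy flattened subsampler output: (batch, time', features).
x = torch.randn(2, 37, encoder_dim * subsampled_bins)
print(linear(x).shape)  # torch.Size([2, 37, 512])
```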
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)

def _scaled_in_proj_weight(self):
OK, I assume the new implementation is numerically equivalent; SDPA is probably the right bet.
I also tried the default multi-head self-attention without the query scaler, and it still couldn't run successfully without adjusting the attention backends. If this is useful for your debugging, I can create a separate branch with this setup.
So the previous implementation was quite long; SDPA is maintained in PyTorch, so this actually makes things better.
The previous implementation was long because it extended nn.MultiheadAttention by modifying the in-projection weights/biases in the forward pass, but the attention itself was still handled entirely by PyTorch.
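For anyone following along, a rough sketch of what SDPA-based self-attention with a learned query scale looks like. This is a stand-in for the QueryScaler/attention code being discussed, not the PR's actual implementation; the shape handling and the exp-parameterized scale are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPASelfAttention(nn.Module):
  """Minimal self-attention via F.scaled_dot_product_attention with a learned
  per-dimension query scale (hypothetical stand-in for QueryScaler)."""

  def __init__(self, embed_dim: int, num_heads: int):
    super().__init__()
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
    self.out_proj = nn.Linear(embed_dim, embed_dim)
    # Learned query scale, one value per head dimension.
    self.q_scale = nn.Parameter(torch.zeros(self.head_dim))

  def forward(self, x, attn_mask=None):
    b, t, d = x.shape
    q, k, v = self.in_proj(x).chunk(3, dim=-1)
    # (b, t, d) -> (b, num_heads, t, head_dim)
    q, k, v = (y.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
               for y in (q, k, v))
    q = q * torch.exp(self.q_scale)  # query scaling before attention
    # Dispatches to whichever SDP backend is enabled (math-only in this PR).
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return self.out_proj(out.transpose(1, 2).reshape(b, t, d))
```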
Confirmed that this works for the 60K run.
Contains changes to fix #497