
Conformer OOM fix #549

Merged: 3 commits merged into mlcommons:dev on Oct 20, 2023
Conversation

chandramouli-sastry (Contributor):

Contains changes to fix #497:

  • I replaced the lazy linear layers with linear layers.
  • Along with a few minor changes/improvements throughout the Conformer model.py, I configured the scaled-dot-product-attention backend to use the math backend, which fixed the out-of-memory error (see the sketch after this list).
  • I updated the comparator and confirmed that the JAX/PyTorch implementations are identical (comparator output screenshot attached).
  • Running with torch.compile=True still gives OOM; this change only fixes the OOM for torch.compile=False (tested with NAdamW and AdamW). I disabled compilation for librispeech_conformers until this can be fixed.
  • For reference, I also profiled the run for 500 steps (profiling screenshot attached) -- not sure if this is in the ballpark :)
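A minimal sketch of the backend change described above (illustrative only, not the PR's exact code): disabling the flash and memory-efficient SDPA kernels makes torch.nn.functional.scaled_dot_product_attention fall back to the math backend. The tensor shapes and the CUDA device below are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

# Force the math SDPA backend by disabling the other kernels
# (mirrors the backend flags added in this PR).
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Hypothetical shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 8, 320, 64, device="cuda")
k = torch.randn(2, 8, 320, 64, device="cuda")
v = torch.randn(2, 8, 320, 64, device="cuda")

# With the other backends disabled, this dispatches to the math implementation.
out = F.scaled_dot_product_attention(q, k, v)
```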

chandramouli-sastry requested a review from a team as a code owner, October 15, 2023 01:14
@github-actions (bot):

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅

torch.backends.cudnn.benchmark = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
Member:
Curious that the math SDP backend was more memory efficient than memory-efficient attention.

chandramouli-sastry (Contributor, Author):
It is likely more memory efficient, but it seems to require a different memory configuration setting to sidestep the need for torch.cuda.empty_cache() after every update step.

Member:
Yeah, I feel like this is some corner-case SDPA bug, but it should be fine to merge still.

@@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
self.conv2 = Conv2dSubsampling(
input_channels=encoder_dim, output_channels=encoder_dim)

-    self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
+    self.linear = nn.Linear(
+        in_features=self.encoder_dim * num_bins // 4,
Member:
what's the reasoning for this?

chandramouli-sastry (Contributor, Author):
Each of the two subsampling layers reduces the number of mel-spectrogram features by half, so the flattened input to the linear layer has encoder_dim * num_bins // 4 features.
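A quick shape check of that reasoning (a sketch with assumed values; encoder_dim=512 and num_bins=80 are illustrative numbers, not necessarily the workload's actual config):

```python
# Two subsampling layers, each halving the frequency dimension,
# reduce num_bins by a factor of 4, so the flattened features entering
# the linear layer number encoder_dim * num_bins // 4.
encoder_dim = 512  # assumed encoder width
num_bins = 80      # assumed number of mel-spectrogram bins

in_features = encoder_dim * num_bins // 4
print(in_features)  # 10240 (= 512 * 20)
```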

self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)

def _scaled_in_proj_weight(self):
Member:
OK, I assume that the new implementation is numerically equivalent; SDPA is probably the right bet.

chandramouli-sastry (Contributor, Author):
I also tried the default multi-head self-attention without the query scaler, and it still couldn't run successfully without adjusting the attention backends. If this is useful for your debugging, I can create a separate branch with this setup.

Member:
So, the previous implementation was quite long; SDPA is maintained in PyTorch itself, so this actually makes things better.

chandramouli-sastry (Contributor, Author) commented Oct 19, 2023:
The previous implementation was long because it extended nn.MultiheadAttention by changing the in-projection weights/biases in the forward pass, but the attention itself was still managed entirely by PyTorch.
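A rough sketch of the pattern being described (hypothetical, not the repo's actual code): subclass nn.MultiheadAttention, rescale the query rows of the packed in-projection weight, and hand the modified weight to PyTorch's own attention routine, so the attention computation itself still lives in PyTorch. The fixed 0.5 scale is a stand-in for a learned query scaler.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledQueryMHA(nn.MultiheadAttention):
    """Hypothetical illustration of query-scaled multi-head attention."""

    def _scaled_in_proj_weight(self):
        # Scale only the query block of the packed in-projection weight;
        # 0.5 stands in for a learned per-dimension scale.
        q_w = self.in_proj_weight[:self.embed_dim] * 0.5
        kv_w = self.in_proj_weight[self.embed_dim:]
        return torch.cat([q_w, kv_w], dim=0)

    def forward(self, query, key, value, **kwargs):
        # Attention is still computed by PyTorch; only the in-projection
        # weight fed to it is modified. Extra kwargs ignored in this sketch.
        return F.multi_head_attention_forward(
            query, key, value,
            self.embed_dim, self.num_heads,
            self._scaled_in_proj_weight(), self.in_proj_bias,
            None, None, False, self.dropout,
            self.out_proj.weight, self.out_proj.bias,
            training=self.training)
```

Usage would look like any nn.MultiheadAttention module, e.g. `mha = ScaledQueryMHA(512, 8)` and `out, _ = mha(x, x, x)` with `x` of shape (seq_len, batch, 512).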

priyakasimbeg (Contributor) commented Oct 19, 2023:

Confirmed that this works for the 60K run:

(60000, {'train/ctc_loss': Array(0.09971161, dtype=float32), 'train/wer': 0.036610075914423744, 'validation/ctc_loss': Array(0.30221355, dtype=float32), 'validation/wer': 0.08587637121438702, 'validation/num_examples': 5348, 'test/ctc_loss': Array(0.16302903, dtype=float32), 'test/wer': 0.05093237974018627, 'test/num_examples': 2472, 'score': 46081.4025554657, 'total_duration': 50274.87093257904, 'accumulated_submission_time': 46081.4025554657, 'accumulated_eval_time': 4190.642651796341, 'accumulated_logging_time': 1.6540186405181885, 'global_step': 60000, 'preemption_count': 0})
I1019 15:00:12.858658 139843127400256 submission_runner.py:550] Timing: 46081.4025554657
I1019 15:00:12.858741 139843127400256 submission_runner.py:552] Total number of evals: 33
I1019 15:00:12.858824 139843127400256 submission_runner.py:553] ====================

priyakasimbeg self-requested a review October 20, 2023 04:35
priyakasimbeg merged commit 25fb3a0 into mlcommons:dev Oct 20, 2023
16 checks passed
github-actions bot locked and limited conversation to collaborators Oct 20, 2023