PyTorch Conformer sometimes OOMs #497
Comments
@janeyx99 have you seen this kind of nondeterministic optimizer OOM? |
this doesn’t look like it’s ooming in the optimizer but rather the backward, no? the fact that it’s nondeterministic is definitely weird…meaning the extra memory can come from anywhere. otherwise i would guess that activation checkpointing would help here. |
I ran the following script, which trains the conformer model for 1000 steps and repeats the run 10 times to see how often it errors out. All 10 times, the model trained successfully without any error. I ran this inside a Docker container built from the Dockerfile on the main branch. @priyakasimbeg, could you please let me know whether the following script errors out on the VM where you got the above error?
|
This is not reproducible anymore after Juhan's fixes in #502 for Criteo1tb memory issues. I believe clearing the cache after evals helped. |
Happened again on git commit 4c38ffb at step 1778 on kasimbeg-2 |
Here is the traceback: https://gist.github.com/priyakasimbeg/35a7e2562ed471aba6d8087da1e65fda. Seems like it happens in the backward pass:
As discussed in the torch dev meeting, I will try turning off fused / foreach in the nadamw optimizer. |
Oh hey, I’m not sure how changing the optimizer would affect peak memory if the OOM is happening in the backward—could you give some rationale on why? Also, I’m not sure you’re aware already, but a memory snapshot of one training loop would be helpful with debugging. Zach Devito’s blog describes how to capture one https://zdevito.github.io/2022/12/09/memory-traces.html and I’m happy to send you a code snippet if you’re interested in capturing a snapshot. |
Hey @janeyx99 thanks for taking a look at this! |
oh ya! here’s a snippet where you surround the code you want to profile:
Then you can drag the snapshot.pickle file to https://zdevito.github.io/assets/trace.html to see the visualization of active memory usage over time |
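For readers without the original snippet, here is a minimal sketch of this kind of capture. It assumes PyTorch 2.1 or later, where the private helpers `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._dump_snapshot` are available. The function name `capture_memory_snapshot` and the `train_step` callable are hypothetical, and this is not necessarily the snippet that was shared in the thread.

```python
try:
    import torch
except ImportError:  # lets the sketch import on machines without PyTorch
    torch = None


def capture_memory_snapshot(train_step, path="snapshot.pickle", max_entries=100_000):
    """Run train_step() with allocator history recording on, then dump a snapshot.

    Returns True if a snapshot was written, False if CUDA is unavailable.
    """
    if torch is None or not torch.cuda.is_available():
        return False
    # Private API: start recording allocation events together with stack traces.
    torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        train_step()  # hypothetical callable: the code you want to profile
    finally:
        # Dump even if train_step raised (for example on an OOM), then stop recording.
        torch.cuda.memory._dump_snapshot(path)
        torch.cuda.memory._record_memory_history(enabled=None)
    return True
```

Recording should start before the model and optimizer state are allocated. Blocks created earlier have no recorded history, which is one possible cause of the 'allocated before _record_history was enabled' message.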
@janeyx99 I ran the profiler. The workload ooms after ~4 steps. Also, when I hover over the blocks I see 'block was allocated before _record_history was enabled' even though I moved the |
@msaroufim I tried capturing a trace using a single GPU and torchrun process but I still see 'block was allocated before _record_history was enabled' for most of the blocks. Do you know how I could get some more useful information out of the profiler? |
Tuning max_split_size_mb to 512 seems to have fixed this. Will send out PR |
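A sketch of how this allocator setting is typically applied, assuming it is passed through the `PYTORCH_CUDA_ALLOC_CONF` environment variable. The helper name `set_max_split_size_mb` is hypothetical, and the PR mentioned above may configure it differently (for example in a Dockerfile or launch script).

```python
import os


def set_max_split_size_mb(size_mb, env=os.environ):
    """Add or overwrite max_split_size_mb in PYTORCH_CUDA_ALLOC_CONF.

    The variable has to be set before the process first uses CUDA.
    """
    existing = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    # Keep any other allocator options that are already configured.
    options = [
        opt.strip()
        for opt in existing.split(",")
        if opt.strip() and not opt.strip().startswith("max_split_size_mb:")
    ]
    options.append(f"max_split_size_mb:{size_mb}")
    env["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(options)
    return env["PYTORCH_CUDA_ALLOC_CONF"]


# 512 is the value from the comment above; simplest is to call this before importing torch.
set_max_split_size_mb(512)
```

The allocator will not split cached blocks larger than this size, which can reduce fragmentation. As the later comments in this thread show, the value chosen can also affect run time.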
ah, changing the max_split_size_mb should help with fragmentation. your memory profile staying flat seems to mean that you enabled the profile when no allocations were made, which is odd because running forward and backward and optimizer should need intermediates at the very least. the distributedness may have something to do with it—maybe this is profiling the wrong device… |
On second thought it looks like reducing the max_split_size_mb to 256 increases the submission time by 2x. @janeyx99 I also tried setting the device to 0 and running on a single process instead of the multiple DDP processes but got the same graph. |
Adjusting the scaled-dot-product-attention backend to use the math backend (slower than other alternatives available but it seems to be similar to the default implementation in 1.13) removes OOM errors. I tried several other things to fix this but they weren't effective without this adjustment. I will adjust the traindiffs test accordingly and test this change before creating a PR for this. |
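One way to pin scaled-dot-product attention to the math backend is sketched below, assuming PyTorch 2.0 or later, where the global toggles in `torch.backends.cuda` exist. The function name `force_math_sdpa` is hypothetical, and the actual change in the PR may use a different mechanism, such as a context manager around the attention call.

```python
try:
    import torch
except ImportError:  # lets the sketch import on machines without PyTorch
    torch = None


def force_math_sdpa():
    """Restrict scaled_dot_product_attention to the math backend, process-wide.

    Returns True if the toggles were applied, False if they are unavailable.
    """
    if torch is None or not hasattr(torch.backends.cuda, "enable_math_sdp"):
        return False
    torch.backends.cuda.enable_flash_sdp(False)          # disable the FlashAttention kernel
    torch.backends.cuda.enable_mem_efficient_sdp(False)  # disable the memory-efficient kernel
    torch.backends.cuda.enable_math_sdp(True)            # keep the reference math path
    return True
```

Calling `force_math_sdpa()` once at startup, before the first forward pass, applies the choice to every subsequent attention call in the process.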
Fixed in #549. |
Hi @priyakasimbeg - is there any additional info related to this being re-opened? |
Upgrading the GPU driver to 535.104.05 seems to resolve the CUDA OOM, so we will upgrade the drivers on the competition hardware and mark this as resolved. We also confirmed, per a recommendation from @lessw2020, that on PyTorch 2.1.0 setting the following option resolves the OOM:
With the driver update we won't have to use this flag after all, but we want to document it in case we run into issues in the future. |
PyTorch Conformer occasionally OOMs.
Description
Traceback:
Steps to Reproduce
Pytorch version: torch.dev08202023