🐛 Describe the bug
Since it's a 34B-parameter model and I'm loading it in bf16, I'd expect roughly 1.2 * 34 * 2 ≈ 82 GB of VRAM, plus a bit extra for the LoRA parameters. However, on my training VM with 192 GB of VRAM across 8 GPUs I'm consistently getting OOM errors.
I can only just barely get it to work when I set context_length to 4096, but this is not ideal.
So what could be happening here? It's as if the model is being loaded onto the GPUs in fp32 instead of bf16, but I'm not sure.
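For reference, here's a back-of-the-envelope version of that estimate in Python. The 1.2x overhead factor is the same rough assumption as above, and this covers only the bf16 weights, not activations, gradients, or optimizer state:

# Rough estimate of bf16 weight memory for a 34B-parameter model.
num_params = 34e9
bytes_per_param = 2            # bf16
overhead = 1.2                 # rough fudge factor from the estimate above
expected_gb = num_params * bytes_per_param * overhead / 1e9
print(f"~{expected_gb:.0f} GB expected for weights vs. 192 GB total across 8 GPUs")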
Error logs
Typical CUDA OOM error
Expected behavior
VRAM usage roughly 2x smaller than it currently is, and training completes without OOM errors.
Before the training loop, enable allocation history recording:
torch.cuda.memory._record_memory_history(trace_alloc_max_entries=100000)
After your training loop, run the following to take and save a memory snapshot (e.g. only on rank 0):
import pickle

# Capture the allocator state and save it for offline inspection.
snapshot = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)
Copy snapshot.pickle to your local machine, go to https://zdevito.github.io/assets/viz/, and drag snapshot.pickle onto the page to visualize the memory trace. Alternatively, use _memory_viz.py to generate the HTML file yourself.
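If you prefer to generate the HTML locally instead of using the website, recent PyTorch releases ship _memory_viz.py as a standalone script; a typical invocation (the exact path depends on your PyTorch install) looks like:

# Renders the allocation trace from the snapshot into an interactive HTML page.
python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html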
Thanks @HamidShojanazeri! In the end I went the full fine-tuning route instead of PEFT. It was much easier for me to throw more GPUs/nodes at the problem than to get PEFT running on a single node.
Hi! It seems that a solution has been provided and there has been no follow-up conversation for a long time, so I will close this issue for now. Feel free to reopen it if you have any questions!