VRAM too high when using PEFT + FSDP + BF16 #276

Closed

ramicaza opened this issue Oct 30, 2023 · 3 comments
ramicaza commented Oct 30, 2023

System Info

Collecting environment information...
PyTorch version: 2.2.0.dev20231028+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:07:37) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-25-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA L4
GPU 1: NVIDIA L4
GPU 2: NVIDIA L4
GPU 3: NVIDIA L4
GPU 4: NVIDIA L4
GPU 5: NVIDIA L4
GPU 6: NVIDIA L4
GPU 7: NVIDIA L4

Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             96
On-line CPU(s) list:                0-95
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:                           7
CPU MHz:                            2200.214
BogoMIPS:                           4400.42
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.5 MiB
L1i cache:                          1.5 MiB
L2 cache:                           48 MiB
L3 cache:                           77 MiB
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.1
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231028+cu118
[pip3] triton==2.1.0
[conda] numpy                     1.26.1                   pypi_0    pypi
[conda] pytorch-triton            2.1.0+6e4932cda8          pypi_0    pypi
[conda] torch                     2.2.0.dev20231028+cu118          pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

I'm running the following:

torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py \
--enable_fsdp --pure_bf16 \
--low_cpu_fsdp \
--use_fast_kernels \
--batch_size_training 1 \
--context_length 16384 \
--use_peft --peft_method lora \
--fsdp_peft_cpu_offload_for_save \
--num_epochs 100 \
--dataset "custom_dataset" --custom_dataset.file "recipe_format_ds.py" \
--model_name CodeLlama-34b-Instruct-hf \
--output_dir full-finetune-test

Since it's a 34B-parameter model and I'm loading it in bf16, I'd expect roughly 1.2 * 34 * 2 GB of VRAM usage, plus a bit extra for the LoRA params. However, on my training VM with 192 GB of VRAM across 8 GPUs I'm consistently getting OOM errors.
I can just BARELY get it to work when I set context_length to 4096, but this is not ideal.

So what could be happening here? It's as if the model is being loaded onto the GPUs as fp32 instead of bf16, but I'm not sure.
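
Spelling out that estimate as a quick sanity check (a rough sketch that counts only the bf16 weights plus ~20% overhead; activations, gradients, and optimizer state are not included):

# Rough back-of-the-envelope check of the expected bf16 weight footprint.
# Illustrative only: counts weights plus ~20% overhead and ignores activations,
# gradients, and optimizer state.
params_billion = 34          # CodeLlama-34B
bytes_per_param = 2          # bf16
overhead = 1.2               # fudge factor from the estimate above
expected_gb = params_billion * bytes_per_param * overhead
total_vram_gb = 8 * 24       # 8x NVIDIA L4 with 24 GB each
print(f"expected ~{expected_gb:.0f} GB for weights vs {total_vram_gb} GB total VRAM")
# -> expected ~82 GB for weights vs 192 GB total VRAM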

Error logs

Typical CUDA OOM error

Expected behavior

VRAM usage roughly 2x smaller than it currently is, and training that completes without OOM errors.

HamidShojanazeri (Contributor) commented
@ramicaza we merged this PR, which now defaults the mixed-precision policy to bf16; this should help to some extent.
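
For reference, a bf16 mixed-precision policy for FSDP looks roughly like the sketch below (illustrative only, not the exact code from the merged PR):

import torch
from torch.distributed.fsdp import MixedPrecision

# Illustrative bf16 policy; the merged PR may configure this differently.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters are gathered and used for compute in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduction happens in bf16
    buffer_dtype=torch.bfloat16,  # module buffers are cast to bf16
)
# Passed when wrapping the model, e.g.:
# model = FSDP(model, mixed_precision=bf16_policy, ...)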

Overall, memory traces can be very helpful. You can take a memory snapshot, following this blog post from Zack: https://zdevito.github.io/2022/12/09/memory-traces.html

Before (or at the start of) the training loop, enable recording of the allocation history:

torch.cuda.memory._record_memory_history(trace_alloc_max_entries=100000)

After your training loop, run the following to take and save a memory snapshot (e.g. only on rank 0):

import pickle
import torch

snapshot = torch.cuda.memory._snapshot()  # capture the current allocation state and trace
with open("snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)

Copy snapshot.pickle to your local machine, go to https://zdevito.github.io/assets/viz/, and drag the file onto the page to visualize the memory trace. Alternatively, use _memory_viz.py to generate an HTML file.
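
For example, an invocation along these lines should produce an HTML trace (the path to _memory_viz.py depends on where PyTorch is installed in your environment):

python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html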
 

ramicaza (Author) commented Nov 8, 2023
Thanks @HamidShojanazeri! In the end I just went the full fine-tuning route instead of PEFT. It was much easier for me to throw more GPUs/nodes at the problem than to try to get PEFT running on a single node.

wukaixingxp self-assigned this May 31, 2024

wukaixingxp (Contributor) commented
Hi! It seems that a solution has been provided for this issue and there has been no follow-up conversation for a long time. I will close this issue for now; feel free to reopen it if you have any questions!
