
SentenceTransformerTrainer consuming unexpectedly large amount of memory #3187

Open
rupeshgx opened this issue Jan 23, 2025 · 0 comments

I am trying to fine-tune gte-Qwen2-7B-instruct (a 7B-parameter model) on an H100 GPU with 80GB of memory. Loading the model consumes 28GB, leaving 52GB for training. I am using LoRA, so only 2.5M of the 7B parameters are trainable. However, I hit OOM as soon as training begins (I cannot even complete one iteration). I do not understand why this is happening. Any advice?
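For context, here is a rough, config-driven sketch of how much activation memory a single forward pass might need at this batch size and sequence length (the per-layer multiplier is a loose assumption, not a measured number):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("Alibaba-NLP/gte-Qwen2-7B-instruct")

batch_size = 16        # per_device_train_batch_size
seq_len = 512          # model.max_seq_length
bytes_per_value = 2    # fp16 activations
per_layer_factor = 8   # loose assumption: attention + MLP intermediates kept per layer

activation_gb = (
    batch_size * seq_len * config.hidden_size
    * config.num_hidden_layers * per_layer_factor * bytes_per_value
) / 1024**3
print(f"Rough activation estimate per forward pass: {activation_gb:.1f} GB")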

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers


def main():
    # 1. Load a model to finetune
    model = SentenceTransformer(
        model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
        tokenizer_kwargs={
            "model_max_length": 512,
            "truncation": True
        }
    )
    # set the max input seq length to 512
    model.max_seq_length = 512

    # set up lora adapter
    # resource: https://github.com/UKPLab/sentence-transformers/releases/tag/v3.3.0
    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model.add_adapter(peft_config)
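    # Note: LoRA reduces the number of trainable weights (and their gradient/optimizer state),
    # but the full 7B model still runs in the forward pass, so activation memory is unchanged.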

    # print trainable params
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
        f"trainable%: {100 * trainable_params / all_param}"
    )

    device = torch.device('cuda:0')
    free, total = torch.cuda.mem_get_info(device)
    mem_total_mb = total / 1024 ** 2
    mem_used_mb = (total - free) / 1024 ** 2
    print(f"Total GPU memory available = {mem_total_mb} MB")
    print(f"Total GPU memory used after loading the model = {mem_used_mb} MB")

    # 2. Load a dataset to finetune on
    dataset = load_dataset(.....)

    # 3. Define the training arguments and loss function
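    # Note: CachedMultipleNegativesRankingLoss caches embeddings GradCache-style, so activation
    # memory scales with mini_batch_size; with mini_batch_size equal to the full batch size (16),
    # the caching gives little or no memory benefit over the plain loss.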
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

    args = SentenceTransformerTrainingArguments(
        output_dir="...",
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.25,
        fp16=True,
        bf16=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=1,
        run_name="fine-tuning",
        # when using DistributedDataParallel, it is recommended to set `dataloader_drop_last=True` to avoid
        # hanging issues with an uneven last batch
        dataloader_drop_last=True,
    )

    # 4. Train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=dataset,
        loss=loss,
    )
    trainer.train()


if __name__ == "__main__":
    main()

Error message:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 296.00 MiB. GPU 0 has a total capacity of 79.18 GiB of which 172.31 MiB is free. Process 895969 has 79.01 GiB memory in use. Of the allocated memory 76.34 GiB is allocated by PyTorch, and 933.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
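
For reference, a minimal lower-memory variant of the setup above that I am considering (a sketch under the assumptions noted inline, not a verified fix):

import torch
from peft import LoraConfig, TaskType
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

# Load the weights in bf16 (~14GB for 7B params instead of ~28GB in fp32); assumes bf16 is acceptable for this model.
model = SentenceTransformer(
    "Alibaba-NLP/gte-Qwen2-7B-instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    tokenizer_kwargs={"model_max_length": 512, "truncation": True},
)
model.max_seq_length = 512
model.add_adapter(
    LoraConfig(task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
)

# A smaller mini_batch_size lets CachedMultipleNegativesRankingLoss process the batch in chunks,
# trading extra compute for lower activation memory.
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=4)

args = SentenceTransformerTrainingArguments(
    output_dir="...",
    per_device_train_batch_size=16,
    bf16=True,   # prefer bf16 over fp16 on an H100
    fp16=False,
    # Recompute activations during the backward pass instead of storing them; with PEFT this may
    # also need gradient_checkpointing_kwargs={"use_reentrant": False} depending on versions.
    gradient_checkpointing=True,
)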