[Qwen2.5] LoRA with SFT seems to be stuck forever with DeepSpeed #2891

Open
sayakpaul opened this issue Feb 18, 2025 · 0 comments
Labels
🚀 deepspeed Related to deepspeed ⚡ PEFT Related to PEFT 🏋 SFT Related to SFT

Comments

@sayakpaul
Member

Training command:

accelerate launch --config_file=deepspeed_zero3.yaml train.py \
  --dataset_name diffusers-internal-dev/ShotDEAD-single-shard \
  --model_name_or_path $MODEL_NAME \
  --attn_implementation "sdpa" \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --output_dir $OUTPUT_DIR \
  --bf16 \
  --use_peft \
  --torch_dtype bfloat16 \
  --gradient_checkpointing
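
The referenced deepspeed_zero3.yaml is not included in this report. For context, a typical Accelerate ZeRO-3 config of that name (for example, the one shipped in trl's examples/accelerate_configs) looks roughly like the sketch below; the exact file used here is not shown, so values such as num_processes are assumptions:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8  # assumed: one process per GPU on the 8xH100 node
rdzv_backend: static
same_network: true
use_cpu: false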
train.py
"""
Adapted from 
https://github.com/huggingface/trl/blob/822653824bf084bc6c042cf0e759f86187c92569/examples/scripts/sft_vlm.py
"""

import torch
from datasets import load_dataset
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image 
import io

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    ################
    # Create a data collator to encode text and image pairs
    ################
    def collator_fn(examples):
        # Get the texts and images, and apply the chat template
        texts = [
            processor.apply_chat_template(example["messages"], tokenize=False) 
            for example in examples
        ]
        images = [Image.open(io.BytesIO(example["image"])).convert("RGB") for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        print(f"{batch.keys()=}")

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # Ignore the image token index in the loss computation (model specific)
        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)["train"]
    splits = dataset.train_test_split(0.1)
    train, val = splits["train"], splits["test"]

    def can_load_image(example):
        # Keep only examples whose raw bytes decode into a valid image
        try:
            Image.open(io.BytesIO(example["image"]))
            return True
        except Exception:
            return False

    train = train.filter(can_load_image)
    val = val.filter(can_load_image)
    
    ################
    # Training
    ################
    model_args.lora_modules_to_save = ["lm_head", "embed_tokens"]  # the embedding module in Qwen2.5-VL is named "embed_tokens"
    model_args.lora_target_modules = "all-linear"
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collator_fn,
        train_dataset=train,
        eval_dataset=val,
        processing_class=processor.tokenizer,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)

Referenced from:
https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py

I am on the latest trl release, with peft and transformers installed from their latest main branches. Tested this on 8xH100s.
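
For reference, installing from main was done roughly as follows (exact commits not recorded, so treat this as an approximation):

pip install -U trl
pip install -U git+https://github.com/huggingface/transformers.git
pip install -U git+https://github.com/huggingface/peft.git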

@github-actions github-actions bot added ⚡ PEFT Related to PEFT 🚀 deepspeed Related to deepspeed 🏋 SFT Related to SFT labels Feb 18, 2025