Training command:
```bash
accelerate launch --config_file=deepspeed_zero3.yaml train.py \
  --dataset_name diffusers-internal-dev/ShotDEAD-single-shard \
  --model_name_or_path $MODEL_NAME \
  --attn_implementation "sdpa" \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --output_dir $OUTPUT_DIR \
  --bf16 \
  --use_peft \
  --torch_dtype bfloat16 \
  --gradient_checkpointing
```
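The `deepspeed_zero3.yaml` file itself is not included in this issue; a ZeRO-3 accelerate config along the lines of trl's `examples/accelerate_configs/deepspeed_zero3.yaml` would look roughly like the sketch below. Treat the values as assumptions, with `num_processes` set to match the 8xH100 setup mentioned further down.

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false
```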
""" Adapted from https://github.com/huggingface/trl/blob/822653824bf084bc6c042cf0e759f86187c92569/examples/scripts/sft_vlm.py """ import torch from datasets import load_dataset from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from PIL import Image import io from trl import ( ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map, get_peft_config, get_quantization_config, ) if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} ################ # Model, Tokenizer & Processor ################ torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) processor = AutoProcessor.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ################ # Create a data collator to encode text and image pairs ################ def collator_fn(examples): # Get the texts and images, and apply the chat template texts = [ processor.apply_chat_template(example["messages"], tokenize=False) for example in examples ] images = [Image.open(io.BytesIO(example["image"])).convert("RGB") for example in examples] # Tokenize the texts and process the images batch = processor(text=texts, images=images, return_tensors="pt", padding=True) print(f"{batch.keys()=}") # The labels are the input_ids, and we mask the padding tokens in the loss computation labels = batch["input_ids"].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 # # Ignore the image token index in the loss computation (model specific) image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) labels[labels == image_token_id] = -100 batch["labels"] = labels return batch ################ # Dataset ################ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)["train"] splits = dataset.train_test_split(0.1) train, val = splits["train"], splits["test"] def filter(example): can_load = False try: _ = Image.open(io.BytesIO(example["image"])) can_load = True except: pass return can_load train = train.filter(filter) val = val.filter(filter) ################ # Training ################ model_args.lora_modules_to_save = ["lm_head", "embed_token"] model_args.lora_target_modules = "all-linear" trainer = SFTTrainer( model=model, args=training_args, data_collator=collator_fn, train_dataset=train, eval_dataset=val, processing_class=processor.tokenizer, peft_config=get_peft_config(model_args), ) trainer.train() # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) if trainer.accelerator.is_main_process: processor.push_to_hub(training_args.hub_model_id)
Referenced from: https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py
I am on the latest trl release, but with peft and transformers installed from their latest main branches. Tested this on 8xH100s.
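For reference, that environment can be reproduced roughly as follows; the issue does not pin exact commits, so the git installs below are an approximation.

```bash
pip install -U trl
pip install git+https://github.com/huggingface/transformers.git
pip install git+https://github.com/huggingface/peft.git
```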