Commit: add DPO and SFT support from TRL on Gaudi, with examples
Signed-off-by: Wang, Yi A <[email protected]>
Showing 8 changed files with 1,148 additions and 0 deletions.
@@ -0,0 +1,36 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model

## Prerequisites

Install all the dependencies in the `requirements.txt`:

```
$ pip install -U -r requirements.txt
```

## Training

There are two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
   - `python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py --training_args.output_dir="sft_output" --training_args.report_to=none`
2. Run the DPO trainer using the merged model saved by the previous step (see the merge sketch below this list):
   - `python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py --model_name_or_path="sft_output/final_merged_checkpoint" --output_dir="dpo_output" --report_to=none`
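The DPO step above points at `sft_output/final_merged_checkpoint`. If your SFT run saved only LoRA adapters rather than a merged model, a minimal sketch of the merge step is below; the adapter path `sft_output/final_checkpoint` is an assumption, adjust it to where your SFT run saved its adapters. If `sft_llama2.py` already writes `final_merged_checkpoint`, this step is unnecessary.

```py
import torch
from peft import AutoPeftModelForCausalLM

# Load the base model with the SFT LoRA adapters attached (adapter path is an assumption).
model = AutoPeftModelForCausalLM.from_pretrained(
    "sft_output/final_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Fold the adapters into the base weights and save a standalone checkpoint
# that dpo_llama2.py can load with AutoModelForCausalLM.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("sft_output/final_merged_checkpoint")
```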
## Running the model

We can load the DPO-trained LoRA adapters saved by the DPO training step via:

```py
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo_output",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

model.generate(...)
```
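A complete generation call additionally needs a tokenizer and a prompt formatted the same way as during training. A minimal sketch, assuming the Llama-2 tokenizer (running on HPU would additionally require `habana_frameworks.torch` and moving the model to the `hpu` device):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Prompts follow the same "Question: ...\n\nAnswer: " template used for training.
prompt = "Question: How do I merge two dictionaries in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```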
@@ -0,0 +1,227 @@
# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiDPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "whether to use gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    sanity_check: bool = False,
    cache_dir: str = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
    )
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    # 1. initialize training arguments:
    training_args = GaudiTrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        use_habana=True,
        use_lazy_mode=True,
        use_hpu_graphs_for_training=True,
        use_hpu_graphs_for_inference=True,
    )
    # 2. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model_ref.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token

    # 3. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 4. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    # 5. initialize the DPO trainer
    dpo_trainer = GaudiDPOTrainer(
        model,
        model_ref,
        gaudi_config=gaudi_config,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()

    # 7. save
    dpo_trainer.save_model(script_args.output_dir)
@@ -0,0 +1,5 @@
trl == 0.7.4
peft == 0.6.2
datasets
wandb
tyro