adding initial code drop for llm finetune (#698)
* adding initial code drop for llm finetune

* (a) fixing padding issue; (b) masking input tokens for eval dataset; (c) adding support for mllogger

* fix masking bug

* adding more logger support

* bug fix

* fix logging bug and update HP

* adding patch for memory issue and fused model enablement

* fixing dataset and model links and updating bash script and readme

* Fix eval batch size, add Dockerfile, improve logging, remove unused code

* Fix eval batch size, add Dockerfile, improve logging, remove unused code

* Remove training_step

* renaming directory and adding more HP values to logger

* adding weight decay to TrainingArguments and BLOCK_START BLOCK_STOP

* editing logging to resolve all checker issues

* fix issue in steps_num logging

* updating bash script for GBS=8

---------

Co-authored-by: Michal Futrega <[email protected]>
itayhubara and michal2409 authored Mar 21, 2024
1 parent 2d0e7ae commit 42aaab3
Showing 10 changed files with 1,262 additions and 0 deletions.
8 changes: 8 additions & 0 deletions llama2_70b_lora/Dockerfile
@@ -0,0 +1,8 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:24.01-py3
FROM ${FROM_IMAGE_NAME}

WORKDIR /workspace/ft-llm
ADD . /workspace/ft-llm

RUN pip install -r requirements.txt
RUN pip install flash-attn==2.4.1 --no-build-isolation
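For reference, a minimal sketch of building an image from this Dockerfile; the tag `llama2-lora` is an arbitrary placeholder, not a name used by the repository:
```bash
# Build the fine-tuning image using the llama2_70b_lora directory as the build context.
docker build -t llama2-lora -f llama2_70b_lora/Dockerfile llama2_70b_lora
```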
100 changes: 100 additions & 0 deletions llama2_70b_lora/README.md
@@ -0,0 +1,100 @@
# LoRA benchmark

LoRA benchmark on GPU (Nvidia A100 80GB). Inspired by [this blog post](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) and [this script](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/train.py).


## Setup

Run the following:
```bash
sudo ./run_docker.sh
cd lora
pip install -r requirements.txt
```

> The Docker run command contains `-v /home/regis_huggingface_co/workspace:/root/workspace --workdir /root/workspace`. Feel free to change these flags to suit your setup.

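As a rough sketch, the kind of invocation `run_docker.sh` wraps might look like the following; the actual script may differ, and the image is the one referenced in the Dockerfile:
```bash
# Hypothetical equivalent of run_docker.sh: interactive container with GPU access,
# mounting the host workspace described in the note above.
docker run --gpus all -it --rm \
    -v /home/regis_huggingface_co/workspace:/root/workspace \
    --workdir /root/workspace \
    nvcr.io/nvidia/pytorch:24.01-py3 bash
```
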
You will also need to run the following to install flash attention:
```
pip install flash-attn --no-build-isolation
```

> For flash attention, make sure that the following command returns 0:
> ```
> ninja --version >/dev/null && echo $?
> ```
> If not, run
> ```
> pip uninstall -y ninja && pip install ninja
> ```
> and install `flash-attn` again.
> More information [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

Make sure you have requested permission to download the Llama 2 weights from the Hugging Face Hub: https://huggingface.co/meta-llama/Llama-2-7b-hf

Then, log in to your Hugging Face account with a read token by running:
```
huggingface-cli login
```
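Optionally, confirm that the token was stored:
```bash
# Prints your Hugging Face username if the login succeeded.
huggingface-cli whoami
```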
Finally, install the MLPerf logging package:
```
git clone https://github.com/mlperf/logging.git mlperf-logging
pip install -e mlperf-logging
```
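A quick check that the editable install is importable (assuming the package exposes the `mlperf_logging` module name, as in the upstream repo):
```bash
# Prints the install location if the package was picked up correctly.
python -c "import mlperf_logging; print(mlperf_logging.__file__)"
```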
## Download Data and Model
The data and model can be downloaded from:
- [mlperf drive - train data](https://drive.google.com/file/d/1-JgY1mEafcJ7qhggt6UR3OEKAciIPd5s/view?usp=sharing)
- [mlperf drive - validation data](https://drive.google.com/file/d/1jrm6Lacrq49AYv0uB_Qy22xRmfPixQvs/view?usp=sharing)
- [mlperf drive - llama-v2 model](https://drive.google.com/drive/folders/1sTeuxkPhwkNPKIPFnOLIYCcK53oB3Ypc?usp=sharing)
By default, the scripts assume the model is under `./llama-v2-fused-qkv` and that both the train and validation data are under the `dataset` folder.
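
The files are hosted on Google Drive, so one option is the `gdown` utility (not part of this repo's requirements, so treat this as an assumption); the URLs below are the links listed above:
```bash
pip install gdown

# Model checkpoint folder; move/rename the result to ./llama-v2-fused-qkv if the
# downloaded folder name differs from the default path the scripts expect.
gdown --folder "https://drive.google.com/drive/folders/1sTeuxkPhwkNPKIPFnOLIYCcK53oB3Ypc"

# Train and validation files go under ./dataset.
mkdir -p dataset && cd dataset
gdown "https://drive.google.com/uc?id=1-JgY1mEafcJ7qhggt6UR3OEKAciIPd5s"
gdown "https://drive.google.com/uc?id=1jrm6Lacrq49AYv0uB_Qy22xRmfPixQvs"
cd ..
```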
## Llama2-70B on 8 devices
Run:
```bash
accelerate launch --config_file configs/default_config.yaml scripts/train.py \
--model_name meta-llama/Llama-2-70b-hf \
--dataset_name "tau/scrolls" --dataset_config_name "gov_report" \
--max_seq_len 8192 \
--bf16 True \
--logging_steps 1 \
--eval_steps 22 \
--output_dir "/tmp/llama-70b" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "input" \
--lr_scheduler_type "cosine" \
--learning_rate 1e-3 \
--warmup_ratio 0.03 \
--use_gradient_checkpointing True \
--use_peft_lora True \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.1 \
--max_steps 440 \
--use_flash_attn \
--lora_target_modules "q_proj,v_proj,k_proj,o_proj"
```
where the Accelerate config file is [this one](https://github.com/regisss/lora/blob/main/configs/default_config.yaml).

> Using flash attention with `--use_flash_attn` is necessary for training on 8k-token sequences.

Learning curves of such a run can be found here: https://huggingface.co/regisss/test_5/tensorboard


## Evaluation

To run evaluation for text summarization, use one of the following:
- Without LoRA adapter weights:
```
python scripts/eval.py --model_name meta-llama/Llama-2-70b-hf --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```
- With LoRA adapter weights:
```
python scripts/eval.py --peft_model_name path_to_my_lora_model --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```
## Expected outcome

A clean output (train and eval loss) of a single run with 440 steps can be found in
```
convergence_example.txt
```
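To compare a new run against this reference, one option is a simple text diff of the loss lines; `run.log` is a hypothetical capture of your own run's output, and the pattern assumes the log lines mention "loss":
```bash
# Side-by-side comparison of loss-related lines from the reference and your run.
diff <(grep -i "loss" convergence_example.txt) <(grep -i "loss" run.log)
```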
22 changes: 22 additions & 0 deletions llama2_70b_lora/configs/default_config.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
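
Values in this file can be overridden at launch time without editing it; a minimal sketch (4 processes chosen arbitrarily, with the remaining training flags from the README appended):
```bash
# CLI flags take precedence over the YAML; other fields still come from the config file.
accelerate launch --config_file configs/default_config.yaml \
    --num_processes 4 \
    scripts/train.py --model_name meta-llama/Llama-2-70b-hf
```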