diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..0c04a67d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +__pycache__ +.ipynb_checkpoints diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..08b500a22 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. 
+ +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..8bac304ee --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to llama-recipes +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to llama-recipes, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..bbe189a3d --- /dev/null +++ b/LICENSE @@ -0,0 +1,125 @@ +LLAMA 2 COMMUNITY LICENSE AGREEMENT +Llama 2 Version Release Date: July 18, 2023 + +"Agreement" means the terms and conditions for use, reproduction, distribution and +modification of the Llama Materials set forth herein. + +"Documentation" means the specifications, manuals and documentation +accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Licensee" or "you" means you, or your employer or any other person or entity (if +you are entering into this Agreement on such person or entity's behalf), of the age +required under applicable laws, rules or regulations to provide legal consent and that +has legal authority to bind your employer or such other person or entity if you are +entering in this Agreement on their behalf. + +"Llama 2" means the foundational large language models and software and +algorithms, including machine-learning model code, trained model weights, +inference-enabling code, training-enabling code, fine-tuning enabling code and other +elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Llama Materials" means, collectively, Meta's proprietary Llama 2 and +Documentation (and any portion thereof) made available under this Agreement. 
+ +"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you +are an entity, your principal place of business is in the EEA or Switzerland) and Meta +Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +By clicking "I Accept" below or by using or distributing any portion or element of the +Llama Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. + + a. Grant of Rights. You are granted a non-exclusive, worldwide, non- +transferable and royalty-free limited license under Meta's intellectual property or +other rights owned by Meta embodied in the Llama Materials to use, reproduce, +distribute, copy, create derivative works of, and make modifications to the Llama +Materials. + + b. Redistribution and Use. + + i. If you distribute or make the Llama Materials, or any derivative works +thereof, available to a third party, you shall provide a copy of this Agreement to such +third party. + ii. If you receive Llama Materials, or any derivative works thereof, from +a Licensee as part of an integrated end user product, then Section 2 of this +Agreement will not apply to you. + + iii. You must retain in all copies of the Llama Materials that you +distribute the following attribution notice within a "Notice" text file distributed as a +part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, +Copyright (c) Meta Platforms, Inc. All Rights Reserved." + + iv. Your use of the Llama Materials must comply with applicable laws +and regulations (including trade compliance laws and regulations) and adhere to the +Acceptable Use Policy for the Llama Materials (available at +https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into +this Agreement. + + v. You will not use the Llama Materials or any output or results of the +Llama Materials to improve any other large language model (excluding Llama 2 or +derivative works thereof). + +2. Additional Commercial Terms. If, on the Llama 2 version release date, the +monthly active users of the products or services made available by or for Licensee, +or Licensee's affiliates, is greater than 700 million monthly active users in the +preceding calendar month, you must request a license from Meta, which Meta may +grant to you in its sole discretion, and you are not authorized to exercise any of the +rights under this Agreement unless or until Meta otherwise expressly grants you +such rights. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE +LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE +PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY +WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR +FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE +FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING +THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR +USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE +LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, +NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS +AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, +CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN +IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF +ANY OF THE FOREGOING. + +5. Intellectual Property. + + a. 
No trademark licenses are granted under this Agreement, and in +connection with the Llama Materials, neither Meta nor Licensee may use any name +or mark owned by or associated with the other or any of its affiliates, except as +required for reasonable and customary use in describing and redistributing the +Llama Materials. + + b. Subject to Meta's ownership of Llama Materials and derivatives made by or +for Meta, with respect to any derivative works and modifications of the Llama +Materials that are made by you, as between you and Meta, you are and will be the +owner of such derivative works and modifications. + + c. If you institute litigation or other proceedings against Meta or any entity +(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama +Materials or Llama 2 outputs or results, or any portion of any of the foregoing, +constitutes infringement of intellectual property or other rights owned or licensable +by you, then any licenses granted to you under this Agreement shall terminate as of +the date such litigation or claim is filed or instituted. You will indemnify and hold +harmless Meta from and against any claim by any third party arising out of or related +to your use or distribution of the Llama Materials. + +6. Term and Termination. The term of this Agreement will commence upon your +acceptance of this Agreement or access to the Llama Materials and will continue in +full force and effect until terminated in accordance with the terms and conditions +herein. Meta may terminate this Agreement if you are in breach of any term or +condition of this Agreement. Upon termination of this Agreement, you shall delete +and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the +termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and +construed under the laws of the State of California without regard to choice of law +principles, and the UN Convention on Contracts for the International Sale of Goods +does not apply to this Agreement. The courts of California shall have exclusive +jurisdiction of any dispute arising out of this Agreement. diff --git a/README.md b/README.md new file mode 100644 index 000000000..ff6adee35 --- /dev/null +++ b/README.md @@ -0,0 +1,158 @@ +# Llama 2 Fine-tuning / Inference Recipes and Examples + +The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face). + +Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama). + + +# Table of Contents +1. [Quick start](#quick-start) +2. [Fine-tuning](#fine-tuning) + - [Single GPU](#single-gpu) + - [Multi GPU One Node](#multiple-gpus-one-node) + - [Multi GPU Multi Node](#multi-gpu-multi-node) +3. [Inference](./inference/inference.md) +4. 
[Model Conversion](#model-conversion-to-hugging-face) +5. [Repository Organization](#repository-organization) +6. [License and Acceptable Use Policy](#license) + + + +# Quick Start + +[Llama 2 Jupyter Notebook](quickstart.ipynb): This jupyter notebook steps you through how to finetune a Llama 2 model on the text summarization task using the [samsum](https://huggingface.co/datasets/samsum). The notebook uses parameter efficient finetuning (PEFT) and int8 quantization to finetune a 7B on a single GPU like an A10 with 24GB gpu memory. + +**Note** All the setting defined in [config files](./configs/) can be passed as args through CLI when running the sctipt, there is no need to change from config files directly. + +**Note** In case need to run PEFT model with FSDP, please make sure to use the PyTorch Nightlies. + +**For more in depth information checkout the following:** + +* [Single GPU Fine-tuning](./docs/single_gpu.md) +* [Multi-GPU Fine-tuning](./docs/mutli_gpu.md) +* [LLM Fine-tuning](./docs/LLM_finetuning.md) +* [Adding custom datasets](./docs/Dataset.md) +* [Inference](./inference/inference.md) +* [FAQs](./docs/FAQ.md) + +## Requirements +To run the examples, make sure to install the requirements using + +```bash + +pip install -r requirements.txt + +``` + +**Please note that the above requirements.txt will install PyTorch 2.0.1 version, in case you want to run FSDP + PEFT, please make sure to install PyTorch nightlies.** + +# Fine-tuning + +For fine-tuning Llama 2 models for your domain-specific use cases recipes for PEFT, FSDP, PEFT+FSDP have been included along with a few test datasets. For details see [LLM Fine-tuning](./docs/LLM_finetuning.md). + +## Single and Multi GPU Finetune + +If you want to dive right into single or multi GPU fine-tuning, run the examples below on a single GPU like A10, T4, V100, A100 etc. +All the parameters in the examples and recipes below need to be further tuned to have desired results based on the model, method, data and task at hand. + +**Note:** +* To change the dataset in the commands below pass the `dataset` arg. Current options for dataset are `grammar_dataset`, `alpaca_dataset`and `samsum_dataset`. A description of the datasets and how to add custom datasets can be found in [Dataset.md](./docs/Dataset.md). For `grammar_dataset`, `alpaca_dataset` please make sure you use the suggested instructions from [here](./docs/single_gpu.md#how-to-run-with-different-datasets) to set them up. + +* Default dataset and other LORA config has been set to `samsum_dataset`. + +* Make sure to set the right path to the model in the [training config](./configs/training.py). + +### Single GPU : + +```bash +#if running on multi-gpu machine +export CUDA_VISIBLE_DEVICES=0 + +python llama_finetuning.py --use_peft --peft_method lora --quantization --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model + +``` + +Here we make use of Parameter Efficient Methods (PEFT) as described in the next section. To run the command above make sure to pass the `peft_method` arg which can be set to `lora`, `llama_adapter` or `prefix`. + +**Note** if you are running on a machine with multiple GPUs please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id` + +**Make sure you set [save_model](configs/training.py) in [training.py](configs/training.py) to save the model. 
Be sure to check the other training settings in [train config](configs/training.py) as well as others in the config folder as needed or they can be passed as args to the training script as well.** + + +### Multiple GPUs One Node: + +**NOTE** please make sure to use PyTorch Nightlies for using PEFT+FSDP . + +```bash + +torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model + +``` + +Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training. + +### Fine-tuning using FSDP Only + +If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs. + +```bash + +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned + +``` + +### Multi GPU Multi Node: + +```bash + +sbatch multi_node.slurm +# Change the num nodes and GPU per nodes in the script before running. + +``` +You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md). + + +# Model conversion to Hugging Face +The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library. + +Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with: + +```bash +## Install HuggingFace Transformers from source +pip install git+https://github.com/huggingface/transformers +cd transformers + +python src/transformers/models/llama/convert_llama_weights_to_hf.py \ + --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir models_hf/7B +``` + +# Repository Organization +This repository is organized in the following way: + +[configs](configs/): Contains the configuration files for PEFT methods, FSDP, Datasets. + +[docs](docs/): Example recipes for single and multi-gpu fine-tuning recipes. + +[ft_datasets](ft_datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses) + + +[inference](inference/): Includes examples for inference for the fine-tuned models and how to use them safely. + +[model_checkpointing](model_checkpointing/): Contains FSDP checkpoint handlers. + +[policies](policies/): Contains FSDP scripts to provide different policies, such as mixed precision, transformer wrapping policy and activation checkpointing along with any precision optimizer (used for running FSDP with pure bf16 mode). + +[utils](utils/): Utility files for: + +- `train_utils.py` provides training/eval loop and more train utils. + +- `dataset_utils.py` to get preprocessed datasets. + +- `config_utils.py` to override the configs received from CLI. + +- `fsdp_utils.py` provides FSDP wrapping policy for PEFT methods. + +- `memory_utils.py` context manager to track different memory stats in train loop. 
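As noted earlier, every setting defined in the config dataclasses can also be passed as a CLI argument, which `config_utils.py` applies on top of the defaults. The sketch below illustrates that override pattern with a trimmed-down `train_config`; the `update_config` helper shown here is a simplified stand-in for illustration, not the repository's actual implementation.

```python
from dataclasses import dataclass, fields


@dataclass
class train_config:
    # Trimmed-down version of configs/training.py, for illustration only.
    model_name: str = "PATH/to/LLAMA/7B"
    batch_size_training: int = 4
    lr: float = 1e-4
    use_peft: bool = False


def update_config(config, **kwargs):
    """Apply CLI-style keyword overrides to a config dataclass instance."""
    valid_fields = {f.name for f in fields(config)}
    for key, value in kwargs.items():
        if key in valid_fields:
            setattr(config, key, value)  # override the default from the dataclass
        else:
            print(f"Warning: unknown config field '{key}' was ignored")
    return config


# e.g. the equivalent of passing --lr 2e-4 --use_peft on the command line
cfg = update_config(train_config(), lr=2e-4, use_peft=True)
print(cfg)
```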
+ +# License +See the License file [here](LICENSE) and Acceptable Use Policy [here](USE_POLICY.md) diff --git a/USE_POLICY.md b/USE_POLICY.md new file mode 100644 index 000000000..4299e1d15 --- /dev/null +++ b/USE_POLICY.md @@ -0,0 +1,49 @@ +# Llama 2 Acceptable Use Policy + +Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy). + +## Prohibited Uses +We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to: + +1. Violate the law or others’ rights, including to: + 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: + 1. Violence or terrorism + 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material + 3. Human trafficking, exploitation, and sexual violence + 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. + 5. Sexual solicitation + 6. Any other criminal activity + 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals + 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services + 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices + 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws + 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials + 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system + + + +2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following: + 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State + 2. Guns and illegal weapons (including weapon development) + 3. Illegal drugs and regulated/controlled substances + 4. Operation of critical infrastructure, transportation technologies, or heavy machinery + 5. Self-harm or harm to others, including suicide, cutting, and eating disorders + 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual + + + +3. 
Intentionally deceive or mislead others, including use of Llama 2 related to the following: + 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation + 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content + 3. Generating, promoting, or further distributing spam + 4. Impersonating another individual without consent, authorization, or legal right + 5. Representing that the use of Llama 2 or outputs are human-generated + 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement +4. Fail to appropriately disclose to end users any known dangers of your AI system + +Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means: + +* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) +* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) +* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) +* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 000000000..83c215b60 --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .peft import lora_config, llama_adapter_config, prefix_config +from .fsdp import fsdp_config +from .training import train_config diff --git a/configs/datasets.py b/configs/datasets.py new file mode 100644 index 000000000..6cb3cf591 --- /dev/null +++ b/configs/datasets.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass + + +@dataclass +class samsum_dataset: + dataset: str = "samsum_dataset" + train_split: str = "train" + test_split: str = "validation" + input_length: int = 2048 + + +@dataclass +class grammar_dataset: + dataset: str = "grammar_dataset" + train_split: str = "ft_datasets/grammar_dataset/gtrain_10k.csv" + test_split: str = "ft_datasets/grammar_dataset/grammar_validation.csv" + input_length: int = 2048 + + +@dataclass +class alpaca_dataset: + dataset: str = "alpaca_dataset" + train_split: str = "train" + test_split: str = "val" + data_path: str = "ft_datasets/alpaca_data.json" \ No newline at end of file diff --git a/configs/fsdp.py b/configs/fsdp.py new file mode 100644 index 000000000..3a9226ede --- /dev/null +++ b/configs/fsdp.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +from dataclasses import dataclass, field +from typing import ClassVar +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + +@dataclass +class fsdp_config: + mixed_precision: bool=True + use_fp16: bool=False + sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. + fsdp_activation_checkpointing: bool=True + pure_bf16: bool = True + optimizer: str= "AdamW" + + + \ No newline at end of file diff --git a/configs/peft.py b/configs/peft.py new file mode 100644 index 000000000..cb88f146b --- /dev/null +++ b/configs/peft.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass, field +from typing import ClassVar, List + +@dataclass +class lora_config: + r: int=8 + lora_alpha: int=32 + target_modules: ClassVar[List[str]]= ["q_proj", "v_proj"] + bias= "none" + task_type: str= "CAUSAL_LM" + lora_dropout: float=0.05 + inference_mode: bool = False + +@dataclass +class llama_adapter_config: + adapter_len: int= 10 + adapter_layers: int= 30 + task_type: str= "CAUSAL_LM" + +@dataclass +class prefix_config: + num_virtual_tokens: int=30 + task_type: str= "CAUSAL_LM" \ No newline at end of file diff --git a/configs/training.py b/configs/training.py new file mode 100644 index 000000000..4c50372da --- /dev/null +++ b/configs/training.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +from dataclasses import dataclass +from typing import ClassVar + + +@dataclass +class train_config: + model_name: str="PATH/to/LLAMA/7B" + enable_fsdp: bool= False + run_validation: bool=True + batch_size_training: int=4 + num_epochs: int=3 + num_workers_dataloader: int=1 + lr: float=1e-4 + weight_decay: float=0.0 + gamma: float= 0.85 + seed: int=42 + use_fp16: bool=False + mixed_precision: bool=True + val_batch_size: int=1 + dataset = "samsum_dataset" + micro_batch_size: int=4 + peft_method: str = "lora" # None , llama_adapter, prefix + use_peft: bool=False + output_dir: str = "PATH/to/save/PEFT/model" + freeze_layers: bool = False + num_freeze_layers: int = 1 + quantization: bool = False + one_gpu: bool = False + save_model: bool = True + dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP + save_optimizer: bool=False # will be used if using FSDP + + + + \ No newline at end of file diff --git a/docs/Dataset.md b/docs/Dataset.md new file mode 100644 index 000000000..68bfab6de --- /dev/null +++ b/docs/Dataset.md @@ -0,0 +1,68 @@ +# Datasets and Evaluation Metrics + +The provided fine tuning script allows you to select between three datasets by passing the `dataset` arg to the `llama_finetuning.py` script. The current options are `grammar_dataset`, `alpaca_dataset`and `samsum_dataset`. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses) + +* [grammar_dataset](https://huggingface.co/datasets/jfleg) contains 150K pairs of english sentences and possible corrections. 
+* [alpaca_dataset](https://github.com/tatsu-lab/stanford_alpaca) provides 52K instruction-response pairs as generated by `text-davinci-003`.
+* [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
+
+## Adding custom datasets
+
+The list of available datasets can easily be extended with custom datasets by following these instructions.
+
+Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters such as data files.
+
+Additionally, there is a preprocessing function for each dataset in the [ft_datasets](../ft_datasets) folder.
+The data returned by the dataset needs to be consumable by the forward method of the fine-tuned model through a call to `model(**data)`.
+For CausalLM models this usually means that the data needs to be in the form of a dictionary with "input_ids", "attention_mask" and "labels" fields.
+
+To add a custom dataset, the following steps need to be performed:
+
+1. Create a dataset configuration following the schema described above. Examples can be found in [configs/datasets.py](../configs/datasets.py).
+2. Create a preprocessing routine which loads the data and returns a PyTorch-style dataset. The signature of the preprocessing function needs to be (dataset_config, tokenizer, split_name), where split_name is the string for the train/validation split as defined in the dataclass. A minimal sketch of such a routine is included at the end of the [FAQ](FAQ.md).
+3. Register the dataset name and preprocessing function by inserting them as key and value into the DATASET_PREPROC dictionary in [utils/dataset_utils.py](../utils/dataset_utils.py).
+4. Set the `dataset` field in the training config to the dataset name, or use the `--dataset` option of the `llama_finetuning.py` training script.
+
+## Application
+Below we list other datasets and their main use cases that can be used for fine tuning.
+ +### Q&A these can be used for evaluation as well +- [MMLU](https://huggingface.co/datasets/lukaemon/mmlu/viewer/astronomy/validation) +- [BoolQ](https://huggingface.co/datasets/boolq) +- [NarrativeQA](https://huggingface.co/datasets/narrativeqa) +- [NaturalQuestions](https://huggingface.co/datasets/natural_questions) (closed-book) +- [NaturalQuestions](https://huggingface.co/datasets/openbookqa) (open-book) +- [QuAC](https://huggingface.co/datasets/quac) +- [HellaSwag](https://huggingface.co/datasets/hellaswag) +- [OpenbookQA](https://huggingface.co/datasets/openbookqa) +- [TruthfulQA](https://huggingface.co/datasets/truthful_qa) ( can be helpful for fact checking/ misinformation of the model) + + +### instruction finetuning +- [Alpaca](https://huggingface.co/datasets/yahma/alpaca-cleaned) 52k instruction tuning +- [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k) 15k 15k instruction tuning + + +### simple text generation for quick tests +[English](https://huggingface.co/datasets/Abirate/english_quotes) quotes 2508 Multi-label text classification, text generation + + +### Reasoning used mostly for evaluation of LLMs +- [bAbI](https://research.facebook.com/downloads/babi/) +- [Dyck](https://huggingface.co/datasets/dyk) +- [GSM8K](https://huggingface.co/datasets/gsm8k) +- [MATH](https://github.com/hendrycks/math) +- [APPS](https://huggingface.co/datasets/codeparrot/apps) +- [HumanEval](https://huggingface.co/datasets/openai_humaneval) +- [LSAT](https://huggingface.co/datasets/dmayhem93/agieval-lsat-ar) +- [Entity matching](https://huggingface.co/datasets/lighteval/EntityMatching) + +### Toxicity evaluation +- [Real_toxic_prompts](https://huggingface.co/datasets/allenai/real-toxicity-prompts) + +### Bias evaluation +- [Crows_pair](https://huggingface.co/datasets/crows_pairs) gender bias +- [WinoGender] gender bias + +### Useful Links +More information on evaluation dataset can be found in [HELM](https://crfm.stanford.edu/helm/latest/) diff --git a/docs/FAQ.md b/docs/FAQ.md new file mode 100644 index 000000000..0b1c280f0 --- /dev/null +++ b/docs/FAQ.md @@ -0,0 +1,19 @@ +# FAQ + +Here we discuss frequently asked questions that may occur and we found useful along the way. + +1. Does FSDP support mixed precision in one FSDP unit? Meaning, in one FSDP unit some of the parameters are in Fp16/Bf16 and others in FP32. + + FSDP requires each FSDP unit to have consistent precision, so this case is not supported at this point. It might be added in future but no ETA at the moment. + +2. How does FSDP handles mixed grad requirements? + + FSDP does not support mixed `require_grad` in one FSDP unit. This means if you are planning to freeze some layers, you need to do it on the FSDP unit level rather than model layer. For example, let us assume our model has 30 decoder layers and we want to freeze the bottom 28 layers and only train 2 top transformer layers. In this case, we need to make sure `require_grad` for the top two transformer layers are set to `True`. + +3. How do PEFT methods work with FSDP in terms of grad requirements/layer freezing? + + We wrap the PEFT modules separate from the transfromer layer in auto_wrapping policy, that would result in PEFT models having `require_grad=True` while the rest of the model is `require_grad=False`. + +4. Can I add custom datasets? + + Yes, you can find more information on how to do that [here](Dataset.md). 
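As a follow-up to question 4, the sketch below outlines what a custom dataset preprocessing routine can look like. It follows the `(dataset_config, tokenizer, split)` signature and the `input_ids`/`attention_mask`/`labels` output format described in [Dataset.md](Dataset.md); the JSON file path and the `prompt`/`completion` field names are hypothetical placeholders, not part of this repository.

```python
import json

import torch
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    """Hypothetical dataset reading prompt/completion pairs from a JSON file."""

    def __init__(self, dataset_config, tokenizer, split="train"):
        # Hypothetical file layout; a real dataset would read whatever
        # path/fields its dataclass in configs/datasets.py points to.
        with open(f"ft_datasets/my_data_{split}.json") as f:
            self.samples = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["prompt"] + sample["completion"]
        ids = self.tokenizer.encode(text) + [self.tokenizer.eos_token_id]
        input_ids = torch.tensor(ids, dtype=torch.int64)
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            # For causal LM fine-tuning the labels mirror the input ids.
            "labels": input_ids.clone(),
        }


def get_my_custom_dataset(dataset_config, tokenizer, split):
    # This is the function you would register in the DATASET_PREPROC
    # dictionary in utils/dataset_utils.py, keyed by the dataset name.
    return MyCustomDataset(dataset_config, tokenizer, split)
```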
diff --git a/docs/LLM_finetuning.md b/docs/LLM_finetuning.md new file mode 100644 index 000000000..369f9119b --- /dev/null +++ b/docs/LLM_finetuning.md @@ -0,0 +1,66 @@ +## LLM Finetuning + +Here we discuss finetuning Llama 2 with a couple of different recipes. We will cover two scenarios here: + + +## 1. **Parameter Efficient Model Finetuning** + This helps make the fine-tuning process more affordable even on 1 consumer grade GPU. These methods enable us to keep the whole model frozen and to just add tiny learnable parameters/ layers into the model. In this way, we just train a very tiny portion of the parameters. The most famous method in this category is [LORA](https://arxiv.org/pdf/2106.09685.pdf), LLaMA Adapter and Prefix-tuning. + + +These methods will address three aspects: + + +- **Cost of full finetuning** – these methods only train a small set of extra parameters instead of the full model, this makes it possible to run these on consumer GPUs. + +- **Cost of deployment** – for each fine-tuned downstream model we need to deploy a separate model; however, when using these methods, only a small set of parameters (few MB instead of several GBs) of the pretrained model can do the job. In this case, for each task we only add these extra parameters on top of the pretrained model so pretrained models can be assumed as backbone and these parameters as heads for the model on different tasks. + +- **Catastrophic forgetting** — these methods also help with forgetting the first task that can happen in finetunings. + +HF [PEFT](https://github.com/huggingface/peft) library provides an easy way of using these methods which we make use of here. Please read more [here](https://huggingface.co/blog/peft). + + + +## 2. **Full/ Partial Parameter Finetuning** + +Full parameter finetuning has its own advantages, in this method there are multiple strategies that can help: + +- Keep the pretrained model frozen and only finetune the task head for example, the classifier model. + + +- Keep the pretrained model frozen and add a few fully connected layers on the top. + + +- Finetuning on all the layers. + +You can also keep most of the layers frozen and only finetune a few layers. There are many different techniques to choose from to freeze/unfreeze layers based on different criteria. + +
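As a concrete illustration of the last strategy, the sketch below freezes the whole model and then unfreezes only the top decoder layers. It assumes the Hugging Face Llama implementation, where the decoder blocks live in `model.model.layers`; the checkpoint path and the number of unfrozen layers are illustrative choices only.

```python
from transformers import AutoModelForCausalLM

# Path of a converted Hugging Face checkpoint; adjust to your setup.
model = AutoModelForCausalLM.from_pretrained("models_hf/7B")

# Freeze every parameter first.
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the top decoder blocks (model.model.layers in the HF Llama layout).
num_trainable_layers = 2  # illustrative choice
for layer in model.model.layers[-num_trainable_layers:]:
    for param in layer.parameters():
        param.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")
```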
+<img src="./images/feature-based_FN.png" alt="Image 1" />
+<img src="./images/featurebased_FN_.png" alt="Image 2" />
+<img src="./images/full-param-FN.png" alt="Image 3" />
+ + + +In this scenario depending on the model size, you might need to go beyond one GPU, especially if your model does not fit into one GPU for training. In this case Llama 2 7B parameter wont fit into one gpu. +The way you want to think about it is, you would need enough GPU memory to keep model parameters, gradients and optimizer states. Where each of these, depending on the precision you are training, can take up multiple times of your parameter count x precision( depending on if its fp32/ 4 bytes, fp16/2 bytes/ bf16/2 bytes). +For example AdamW optimizer keeps 2 parameters for each of your parameters and in many cases these are kept in fp32. This implies that depending on how many layers you are training/ unfreezing your GPU memory can grow beyond one GPU. + +**FSDP (FUlly Sharded Data Parallel)** + + +Pytorch has the FSDP package for training models that do not fit into one GPU. FSDP lets you train a much larger model with the same amount of resources. Prior to FSDP was DDP (Distributed Data Parallel) where each GPU was holding a full replica of the model and would only shard the data. At the end of backward pass it would sync up the gradients. + +FSDP extends this idea, not only sharding the data but also model parameters, gradients and optimizer states. This means each GPU will only keep one shard of the model. This will result in huge memory savings that enable us to fit a much larger model into the same number of GPU. As an example in DDP the most you could fit into a GPU with 16GB memory is a model around 700M parameters. So, suppose you had 4 GPUs, in this case even though you access 4 GPUs, you still can't scale beyond the model size that can fit into one GPU. However with FSDP you can fit a 3B model into 4 GPUs, > 4x larger model. + + +Please read more on FSDP here. + + +To boost the performance of fine-tuning with FSDP, we can make use a number of features such as: + +- **Mixed Precision** which in FSDP is much more flexible compared to Autocast. It gives user control over setting precision for model parameters, buffers and gradients. + +- **Activation Checkpointing** which is a technique to save memory by discarding the intermediate activation in forward pass instead of keeping it in the memory with the cost recomputing them in the backward pass. FSDP Activation checkpointing is shard aware meaning we need to apply it after wrapping the model with FSDP. In our script we are making use of that. + +- **auto_wrap_policy** Which is the way to specify how FSDP would partition the model, there is default support for transformer wrapping policy. This allows FSDP to form each FSDP unit ( partition of the model ) based on the transformer class in the model. To identify this layer in the model, you need to look at the layer that wraps both the attention layer and MLP. This helps FSDP have more fine-grained units for communication that help with optimizing the communication cost. 
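For illustration, a minimal FSDP setup combining the transformer wrapping policy with BF16 mixed precision could look like the sketch below. This is a simplified outline rather than the training script in this repository; the checkpoint path is a placeholder, and it assumes the script is launched with `torchrun` so the process group can be initialized.

```python
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# torchrun sets LOCAL_RANK and the variables needed by init_process_group.
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = LlamaForCausalLM.from_pretrained("models_hf/7B")  # placeholder path

# Each LlamaDecoderLayer (attention + MLP) becomes one FSDP unit.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=bf16_policy,
    device_id=torch.cuda.current_device(),
)
```

Wrapping at the decoder-block level keeps each FSDP unit aligned with one attention-plus-MLP block, which is what the transformer wrapping policy described above aims for.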
diff --git a/docs/images/feature-based_FN.png b/docs/images/feature-based_FN.png new file mode 100644 index 000000000..7bc25ccda Binary files /dev/null and b/docs/images/feature-based_FN.png differ diff --git a/docs/images/featurebased_FN_.png b/docs/images/featurebased_FN_.png new file mode 100644 index 000000000..1bbd73014 Binary files /dev/null and b/docs/images/featurebased_FN_.png differ diff --git a/docs/images/full-param-FN.png b/docs/images/full-param-FN.png new file mode 100644 index 000000000..e97792b46 Binary files /dev/null and b/docs/images/full-param-FN.png differ diff --git a/docs/mutli_gpu.md b/docs/mutli_gpu.md new file mode 100644 index 000000000..b0ca2e9f0 --- /dev/null +++ b/docs/mutli_gpu.md @@ -0,0 +1,156 @@ +# Fine-tuning with Multi GPU + +To run fine-tuning on multi-GPUs, we will make use of two packages: + +1. [PEFT](https://huggingface.co/blog/peft) methods and in particular using the Hugging Face [PEFT](https://github.com/huggingface/peft)library. + +2. [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html) which helps us parallelize the training over mutiple GPUs. [More details](LLM_finetuning.md/#2-full-partial-parameter-finetuning). + +Given the combination of PEFT and FSDP, we would be able to fine tune a Llama 2 model on multiple GPUs in one node or multi-node. + +## Requirements +To run the examples, make sure to install the requirements using + +```bash + +pip install -r requirements.txt + +``` + +**Please note that the above requirements.txt will install PyTorch 2.0.1 version, in case you want to run FSDP + PEFT, please make sure to install PyTorch nightlies.** + +## How to run it + +Get access to a machine with mutiple GPUs ( in this case we tested with 4 A100 and A10s). +This runs with the `samsum_dataset` for summarization application by default. + +**Multiple GPUs one node**: + +```bash + +torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model + +``` + +The args used in the command above are: + +* `--enable_fsdp` boolean flag to enable FSDP in the script + +* `--use_peft` boolean flag to enable PEFT methods in the script + +* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`. + +We use `torchrun` here to spawn multiple processes for FSDP. + + +### Fine-tuning using FSDP Only + +If interested in running full parameter finetuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs. + +```bash + +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 + +``` + +**Multi GPU multi node**: + +Here we use a slurm script to schedule a job with slurm over multiple nodes. + +```bash + +sbatch multi_node.slurm +# Change the num nodes and GPU per nodes in the script before running. + +``` + +## How to run with different datasets? + +Currenty 4 datasets are supported that can be found in [Datasets config file](../configs/datasets.py). + +* `grammar_dataset` : use this [notebook](../ft_datasets/grammar_dataset/grammar_dataset_process.ipynb) to pull and process theJfleg and C4 200M datasets for grammar checking. 
+ +* `alpaca_dataset` : to get this open source data please download the `aplaca.json` to `ft_dataset` folder. + +```bash +wget -P ft_dataset https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json +``` + +* `samsum_dataset` + +To run with each of the datasets set the `dataset` flag in the command as shown below: + +```bash +# grammer_dataset +torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --dataset grammar_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --output_dir Path/to/save/PEFT/model + +# alpaca_dataset + +torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --dataset alpaca_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --output_dir Path/to/save/PEFT/model + + +# samsum_dataset + +torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --dataset samsum_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --output_dir Path/to/save/PEFT/model + +``` + +## Where to configure settings? + +* [Training config file](../configs/training.py) is the main config file that helps to specify the settings for our run and can be found in [configs folder](../configs/) + +It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings: + +```python + +model_name: str="PATH/to/LLAMA 2/7B" +enable_fsdp: bool= False +run_validation: bool=True +batch_size_training: int=4 +num_epochs: int=3 +num_workers_dataloader: int=2 +lr: float=2e-4 +weight_decay: float=0.0 +gamma: float= 0.85 +use_fp16: bool=False +mixed_precision: bool=True +val_batch_size: int=4 +dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset +micro_batch_size: int=1 +peft_method: str = "lora" # None , llama_adapter, prefix +use_peft: bool=False +output_dir: str = "./ft-output" +freeze_layers: bool = False +num_freeze_layers: int = 1 +quantization: bool = False +save_model: bool = False +dist_checkpoint_root_folder: str="model_checkpoints" +dist_checkpoint_folder: str="fine-tuned" +save_optimizer: bool=False + +``` + +* [Datasets config file](../configs/datasets.py) provides the avaiable options for datasets. + +* [peft config file](../configs/peft.py) provides the suported PEFT methods and respective settings that can be modified. + +* [FSDP config file](../configs/fsdp.py) provides FSDP settings such as: + + * `mixed_precision` boolean flag to specify using mixed precision, defatults to true. + + * `use_fp16` boolean flag to specify using FP16 for mixed precision, defatults to False. We recommond not setting this flag, and only set `mixed_precision` that will use `BF16`, this will help with speed and memory savings while avoiding challenges of scaler accuracies with `FP16`. + + * `sharding_strategy` this specifies the sharding strategy for FSDP, it can be: + * `FULL_SHARD` that shards model parameters, gradients and optimizer states, results in the most memory savings. + + * `SHARD_GRAD_OP` that shards gradinets and optimizer states and keeps the parameters after the first `all_gather`. 
This reduces communication overhead specially if you are using slower networks more specifically beneficial on multi-node cases. This comes with the trade off of higher memory consumption. + + * `NO_SHARD` this is equivalant to DDP, does not shard model parameters, gradinets or optimizer states. It keeps the full parameter after the first `all_gather`. + + * `HYBRID_SHARD` available on PyTorch Nightlies. It does FSDP within a node and DDP between nodes. It's for multi-node cases and helpful for slower networks, given your model will fit into one node. + +* `checkpoint_type` specifies the state dict checkpoint type for saving the model. `FULL_STATE_DICT` streams state_dict of each model shard from a rank to CPU and assembels the full state_dict on CPU. `SHARDED_STATE_DICT` saves one checkpoint per rank, and enables the re-loading the model in a different world size. + +* `fsdp_activation_checkpointing` enables activation checkpoining for FSDP, this saves siginificant amount of memory with the trade off of recomputing itermediate activations during the backward pass. The saved memory can be re-invested in higher batch sizes to increase the throughput. We recommond you use this option. + +* `pure_bf16` it moves the model to `BFloat16` and if `optimizer` is set to `anyprecision` then optimizer states will be kept in `BFloat16` as well. You can use this option if neccessary. diff --git a/docs/single_gpu.md b/docs/single_gpu.md new file mode 100644 index 000000000..d9ccb7b05 --- /dev/null +++ b/docs/single_gpu.md @@ -0,0 +1,111 @@ +# Fine-tuning with Single GPU + +To run fine-tuning on a single GPU, we will make use of two packages + +1- [PEFT](https://huggingface.co/blog/peft) methods and in specific using HuggingFace [PEFT](https://github.com/huggingface/peft)library. + +2- [BitandBytes](https://github.com/TimDettmers/bitsandbytes) int8 quantization. + +Given combination of PEFT and Int8 quantization, we would be able to fine_tune a Llama 2 7B model on one consumer grade GPU such as A10. + +## Requirements +To run the examples, make sure to install the requirements using + +```bash + +pip install -r requirements.txt + +``` + +**Please note that the above requirements.txt will install PyTorch 2.0.1 version, in case you want to run FSDP + PEFT, please make sure to install PyTorch nightlies.** + +## How to run it? + +Get access to a machine with one GPU or if using a multi-GPU macine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id` and run the following. It runs by default with `samsum_dataset` for summarization application. + + +```bash + +python ../llama_finetuning.py --use_peft --peft_method lora --quantization --use_fp16 --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model + +``` +The args used in the command above are: + +* `--use_peft` boolean flag to enable PEFT methods in the script + +* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`. + +* `--quantization` boolean flag to enable int8 quantization + + +## How to run with different datasets? + +Currenty 4 datasets are supported that can be found in [Datasets config file](../configs/datasets.py). + +* `grammar_dataset` : use this [notebook](../ft_datasets/grammar_dataset/grammar_dataset_process.ipynb) to pull and process theJfleg and C4 200M datasets for grammar checking. + +* `alpaca_dataset` : to get this open source data please download the `aplaca.json` to `ft_dataset` folder. 
+ +```bash +wget -P ft_dataset https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json +``` + +* `samsum_dataset` + +to run with each of the datasets set the `dataset` flag in the command as shown below: + +```bash +# grammer_dataset + +python ../llama_finetuning.py --use_peft --peft_method lora --quantization --dataset grammar_dataset --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model + +# alpaca_dataset + +python ../llama_finetuning.py --use_peft --peft_method lora --quantization --dataset alpaca_dataset --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model + + +# samsum_dataset + +python ../llama_finetuning.py --use_peft --peft_method lora --quantization --dataset samsum_dataset --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model + +``` + +## Where to configure settings? + +* [Training config file](../configs/training.py) is the main config file that help to specify the settings for our run can be found in + +It let us specify the training settings, everything from `model_name` to `dataset_name`, `batch_size` etc. can be set here. Below is the list of supported settings: + +```python + +model_name: str="PATH/to/LLAMA 2/7B" +enable_fsdp: bool= False +run_validation: bool=True +batch_size_training: int=4 +num_epochs: int=3 +num_workers_dataloader: int=2 +lr: float=2e-4 +weight_decay: float=0.0 +gamma: float= 0.85 +use_fp16: bool=False +mixed_precision: bool=True +val_batch_size: int=4 +dataset = "samsum_dataset" # alpaca_dataset,grammar_dataset +micro_batch_size: int=1 +peft_method: str = "lora" # None , llama_adapter, prefix +use_peft: bool=False +output_dir: str = "./ft-output" +freeze_layers: bool = False +num_freeze_layers: int = 1 +quantization: bool = False +one_gpu: bool = False +save_model: bool = False +dist_checkpoint_root_folder: str="model_checkpoints" +dist_checkpoint_folder: str="fine-tuned" +save_optimizer: bool=False + +``` + +* [Datasets config file](../configs/datasets.py) provides the avaiable options for datasets. + +* [peft config file](../configs/peft.py) provides the suported PEFT methods and respective settings that can be modified. diff --git a/ft_datasets/__init__.py b/ft_datasets/__init__.py new file mode 100644 index 000000000..1456dd9e2 --- /dev/null +++ b/ft_datasets/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .grammar_dataset import get_dataset as get_grammar_dataset +from .alpaca_dataset import InstructionDataset as get_alpaca_dataset +from .samsum_dataset import get_preprocessed_samsum as get_samsum_dataset \ No newline at end of file diff --git a/ft_datasets/alpaca_dataset.py b/ft_datasets/alpaca_dataset.py new file mode 100644 index 000000000..4d492460f --- /dev/null +++ b/ft_datasets/alpaca_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html + +import copy +import json +import os +import torch + +from sentencepiece import SentencePieceProcessor +from torch.utils.data import Dataset +from typing import List + +PROMPT_DICT = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. 
" + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), +} + +class InstructionDataset(Dataset): + def __init__(self, dataset_config, tokenizer, partition="train", max_words=30): + self.ann = json.load(open(dataset_config.data_path)) + if partition == "train": + self.ann = self.ann + else: + self.ann = self.ann[:200] + + self.max_words = max_words + # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model") + self.tokenizer = tokenizer + # self.tokenizer1 = tokenizer + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + ann = self.ann[index] + if ann.get("input", "") == "": + prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) + else: + prompt = PROMPT_DICT["prompt_input"].format_map(ann) + example = prompt + ann["output"] + prompt = torch.tensor( + self.tokenizer.encode(prompt), dtype=torch.int64 + ) + example = self.tokenizer.encode(example) + example.append(self.tokenizer.eos_token_id) + example = torch.tensor( + example, dtype=torch.int64 + ) + padding = self.max_words - example.shape[0] + if padding > 0: + example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) + elif padding < 0: + example = example[: self.max_words] + labels = copy.deepcopy(example) + labels[: len(prompt)] = -1 + example_mask = example.ge(0) + label_mask = labels.ge(0) + example[~example_mask] = 0 + labels[~label_mask] = 0 + example_mask = example_mask.float() + label_mask = label_mask.float() + + return { + "input_ids": example, + "labels": labels, + "attention_mask":example_mask, + } diff --git a/ft_datasets/grammar_dataset/__init__.py b/ft_datasets/grammar_dataset/__init__.py new file mode 100644 index 000000000..0184a8037 --- /dev/null +++ b/ft_datasets/grammar_dataset/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .grammar_dataset import get_dataset diff --git a/ft_datasets/grammar_dataset/grammar_dataset.py b/ft_datasets/grammar_dataset/grammar_dataset.py new file mode 100644 index 000000000..cd2e74cd6 --- /dev/null +++ b/ft_datasets/grammar_dataset/grammar_dataset.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# For dataset details visit: https://huggingface.co/datasets/jfleg +# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb + +import argparse +import csv +import glob +import os +import json +import time +import logging +import random +import re +from itertools import chain +from string import punctuation + + +import pandas as pd +import numpy as np +import torch +from torch.utils.data import Dataset + +from datasets import load_dataset +from pathlib import Path + +from ft_datasets.utils import ConcatDataset + + + +class grammar(Dataset): + def __init__( + self, + tokenizer, + csv_name=None, + ): + + try: + self.dataset = load_dataset( + "csv", + data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"}, + delimiter=",", + ) + except Exception as e: + print("Loading of grammar dataset failed! 
Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.") + raise e + + # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path) + # if num_samples: + # self.dataset = self.dataset.select(list(range(0, num_samples))) + self.tokenizer = tokenizer + self.print_text = False # print_text + + def __len__(self): + return self.dataset["train"].shape[0] + + def convert_to_features(self, example_batch): + + # Create prompt and tokenize contexts and questions + + if self.print_text: + print("Input Text: ", self.clean_text(example_batch["text"])) + + input_ = example_batch["input"] + target_ = example_batch["target"] + + prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}" + sample = self.tokenizer(prompt) + + return sample + + def __getitem__(self, index): + sample = self.convert_to_features(self.dataset["train"][index]) + source_ids = sample["input_ids"] + + src_mask = sample["attention_mask"] + + return { + "input_ids": source_ids, + "attention_mask": src_mask, + "labels": source_ids.copy(), + } + + +def get_dataset( + dataset_config, tokenizer, csv_name=None +): + """cover function for handling loading the working dataset""" + """dataset loading""" + if csv_name is None: + currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv" + print(f"Loading dataset {currPath}") + csv_name = str(currPath) + dataset = grammar( + tokenizer=tokenizer, + csv_name=csv_name, + ) + + return ConcatDataset(dataset, chunk_size=dataset_config.input_length) + diff --git a/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb b/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb new file mode 100644 index 000000000..ccbddca6c --- /dev/null +++ b/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb @@ -0,0 +1,463 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n", + "\n", + "Use this notebook to pull in datasets and apply pre-processing. Most grammar datasets unfortunately require preprocessing before being usable in training. 
(example - jfleg has 4 targets per input, so we have to rematch as 1:1 pairings) " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + + "source": [ + "import csv\n", + "from datasets import load_metric, load_dataset\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "list_replacements = [\n", + " (\" .\", \".\"), \n", + " (\" ,\", \",\"),\n", + " (\" '\", \"'\"),\n", + " (\" ?\", \"?\"),\n", + " (\" !\", \"!\"),\n", + " (\" :\", \"!\"),\n", + " (\" ;\", \"!\"),\n", + " (\" n't\", \"n't\"),\n", + " (\" v\", \"n't\"),\n", + " (\"2 0 0 6\", \"2006\"),\n", + " (\"5 5\", \"55\"),\n", + " (\"4 0 0\", \"400\"),\n", + " (\"1 7-5 0\", \"1750\"),\n", + " (\"2 0 %\", \"20%\"),\n", + " (\"5 0\", \"50\"),\n", + " (\"1 2\", \"12\"),\n", + " (\"1 0\", \"10\"),\n", + " ('\" ballast water', '\"ballast water')\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def correct_spacing(item):\n", + " \"\"\" we iterate through the list of all replacements per each item in dataset\"\"\"\n", + " for fix in list_replacements:\n", + " item = item.replace(fix[0], fix[1])\n", + " return item\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_csv(csv_path, dataset):\n", + " \"\"\" apply spacing corrections and save out matched pairs to csv file as dataset\"\"\"\n", + " with open(csv_path, 'w', newline='') as csvfile:\n", + " writer = csv.writer(csvfile)\n", + " writer.writerow([\"input\", \"target\"])\n", + " for case in dataset:\n", + " \t # Adding the t5 task indication prefix to input \n", + + " input_text = case[\"sentence\"]\n", + + " input_text = correct_spacing(input_text)\n", + "\n", + " for correction in case[\"corrections\"]:\n", + " correction = correct_spacing(correction)\n", + " # a few of the cases contain blank strings. 
\n", + " if input_text and correction:\n", + " writer.writerow([input_text, correction])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In Jfleg - validation will be used as 'train', test will be 'validation'" + ] + }, + { + "cell_type": "code", + + "execution_count": 5, + + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + + "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n", + "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n" + + ] + } + ], + "source": [ + "train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n", + "eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n" + ] + }, + { + "cell_type": "code", + + "execution_count": 6, + + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset({\n", + " features: ['sentence', 'corrections'],\n", + " num_rows: 755\n", + "})\n", + "Dataset({\n", + " features: ['sentence', 'corrections'],\n", + " num_rows: 748\n", + "})\n" + ] + } + ], + "source": [ + "print(train_dataset)\n", + "print(eval_dataset)\n" + ] + }, + { + "cell_type": "code", + + "execution_count": 7, + + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n", + "['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n" + ] + } + ], + "source": [ + "print(train_dataset['sentence'][22])\n", + "print(train_dataset['corrections'][22])" + ] + }, + { + "cell_type": "code", + + "execution_count": 8, + + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. 
'" + ] + }, + + "execution_count": 8, + + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clean22 = correct_spacing(train_dataset['sentence'][22])\n", + "clean22" + ] + }, + { + "cell_type": "code", + + "execution_count": 9, + + "metadata": {}, + "outputs": [], + "source": [ + "jfleg_dir = Path.cwd()/'jfleg_dataset' # if you only use 'jfleg', hf will try and use that and complain\n", + "jfleg_dir.mkdir(parents=True,exist_ok=True)\n", + "c4_dir = Path.cwd()/'c4_dataset'\n", + "c4_dir.mkdir(parents=True,exist_ok=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Process Jfleg data " + ] + }, + { + "cell_type": "code", + + "execution_count": 10, + + "metadata": {}, + "outputs": [], + "source": [ + "j_train_file = jfleg_dir/'jtrain.csv'\n", + "j_eval_file = jfleg_dir/'jeval.csv'" + ] + }, + { + "cell_type": "code", + + "execution_count": 11, + + "metadata": {}, + "outputs": [], + "source": [ + "generate_csv(j_train_file, train_dataset)" + ] + }, + { + "cell_type": "code", + + "execution_count": 12, + + "metadata": {}, + "outputs": [], + "source": [ + "generate_csv(j_eval_file, eval_dataset)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Process C4_200M (!) - we'll pull 10K to start" + ] + }, + { + "cell_type": "code", + + "execution_count": 13, + + "metadata": {}, + "outputs": [], + "source": [ + "c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)" + ] + }, + { + "cell_type": "code", + + "execution_count": 14, + + "metadata": {}, + "outputs": [], + "source": [ + "iterator = iter(c4_dataset['train'])" + ] + }, + { + "cell_type": "code", + + "execution_count": 15, + + "metadata": {}, + "outputs": [], + "source": [ + "def c4_generate_csv(csv_path, iterator, num_examples):\n", + " with open(csv_path, 'w', newline='') as csvfile:\n", + " writer = csv.writer(csvfile)\n", + " writer.writerow([\"input\", \"target\"])\n", + " for i in range(0,num_examples):\n", + " data = next(iterator)\n", + + " input_text = data[\"input\"]\n", + + " input_text = correct_spacing(input_text)\n", + " correction = correct_spacing(data[\"output\"])\n", + " if input_text and correction:\n", + " writer.writerow([input_text, correction])" + ] + }, + { + "cell_type": "code", + + "execution_count": 16, + + "metadata": {}, + "outputs": [], + "source": [ + "c4_dir = Path.cwd()/'c4_dataset'\n", + "c4_dir.mkdir(parents=True,exist_ok=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can modify the following to make the csv file with desired number of instances, here we go for 10k to make a quick test" + ] + }, + { + "cell_type": "code", + + "execution_count": 17, + + "metadata": {}, + "outputs": [], + "source": [ + "c4_filename = c4_dir/'c4train_10k.csv'" + ] + }, + { + "cell_type": "code", + + "execution_count": 18, + + "metadata": {}, + "outputs": [], + "source": [ + "c4_generate_csv(c4_filename, iterator, num_examples=10000)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a single training file by combining jtrain and c4train" + ] + }, + { + "cell_type": "code", + + "execution_count": 19, + + "metadata": {}, + "outputs": [], + "source": [ + "merge_list = [j_train_file, c4_filename, ]" + ] + }, + { + "cell_type": "code", + + "execution_count": 20, + + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + + 
"execution_count": 21, + + "metadata": {}, + "outputs": [], + "source": [ + "combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n" + ] + }, + { + "cell_type": "code", + + "execution_count": 22, + + "metadata": {}, + "outputs": [], + "source": [ + "merged_name = \"gtrain_10k.csv\"" + ] + }, + { + "cell_type": "code", + + "execution_count": 23, + + "metadata": {}, + "outputs": [], + "source": [ + "combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )" + ] + }, + { + "cell_type": "code", + + "execution_count": 24, + + "metadata": {}, + "outputs": [], + "source": [ + "eval_name = \"grammar_validation.csv\"" + ] + + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "eval_csv = pd.read_csv(j_eval_file)\n", + "eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )" + ] + + } + ], + "metadata": { + "interpreter": { + "hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac" + }, + "kernelspec": { + + "display_name": "Python 3 (ipykernel)", + + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + + } + }, + "nbformat": 4, + "nbformat_minor": 4 + +} diff --git a/ft_datasets/samsum_dataset.py b/ft_datasets/samsum_dataset.py new file mode 100644 index 000000000..a178e06d9 --- /dev/null +++ b/ft_datasets/samsum_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# For dataset details visit: https://huggingface.co/datasets/samsum + +import datasets +from .utils import Concatenator + +def get_preprocessed_samsum(dataset_config, tokenizer, split): + dataset = datasets.load_dataset("samsum", split=split) + + prompt = ( + f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}" + ) + + def apply_prompt_template(sample): + return { + "text": prompt.format( + dialog=sample["dialogue"], + summary=sample["summary"], + eos_token=tokenizer.eos_token, + ) + } + + dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) + + dataset = dataset.map( + lambda sample: tokenizer(sample["text"]), + batched=True, + remove_columns=list(dataset.features), + ).map(Concatenator(), batched=True) + return dataset diff --git a/ft_datasets/utils.py b/ft_datasets/utils.py new file mode 100644 index 000000000..3263d806a --- /dev/null +++ b/ft_datasets/utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +from tqdm import tqdm +from itertools import chain +from torch.utils.data import Dataset + +class Concatenator(object): + def __init__(self, chunk_size=2048): + self.chunk_size=chunk_size + self.residual = {"input_ids": [], "attention_mask": []} + + def __call__(self, batch): + concatenated_samples = { + k: v + list(chain(*batch[k])) for k, v in self.residual.items() + } + + total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]]) + + if total_length >= self.chunk_size: + chunk_num = total_length // self.chunk_size + result = { + k: [ + v[i : i + self.chunk_size] + for i in range(0, chunk_num * self.chunk_size, self.chunk_size) + ] + for k, v in concatenated_samples.items() + } + self.residual = { + k: v[(chunk_num * self.chunk_size) :] + for k, v in concatenated_samples.items() + } + else: + result = concatenated_samples + self.residual = {k: [] for k in concatenated_samples.keys()} + + result["labels"] = result["input_ids"].copy() + + return result + +class ConcatDataset(Dataset): + def __init__(self, dataset, chunk_size=4096): + self.dataset = dataset + self.chunk_size = chunk_size + + self.samples = [] + + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + + for sample in tqdm(self.dataset, desc="Preprocessing dataset"): + buffer = {k: v + sample[k] for k,v in buffer.items()} + + while len(next(iter(buffer.values()))) > self.chunk_size: + self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) + buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} + + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return len(self.samples) \ No newline at end of file diff --git a/inference/chat_completion.py b/inference/chat_completion.py new file mode 100644 index 000000000..274f381f9 --- /dev/null +++ b/inference/chat_completion.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# from accelerate import init_empty_weights, load_checkpoint_and_dispatch +import fire +import torch +import os +import sys +import warnings +from typing import List + +from peft import PeftModel, PeftConfig +from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM +from safety_utils import get_safety_checker +from model_utils import load_model, load_peft_model +from chat_utils import read_dialogs_from_file, format_tokens + +def main( + model_name, + peft_model: str=None, + quantization: bool=False, + max_new_tokens =256, #The maximum numbers of tokens to generate + min_new_tokens:int=0, #The minimum numbers of tokens to generate + prompt_file: str=None, + seed: int=42, #seed value for reproducibility + safety_score_threshold: float=0.5, + do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise. + use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: float=1.0, # [optional] The value used to modulate the next token probabilities. + top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering. 
+ repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty. + length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. + enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api + enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs + enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 + **kwargs +): + if prompt_file is not None: + assert os.path.exists( + prompt_file + ), f"Provided Prompt file does not exist {prompt_file}" + + dialogs= read_dialogs_from_file(prompt_file) + + elif not sys.stdin.isatty(): + dialogs = "\n".join(sys.stdin.readlines()) + else: + print("No user prompt provided. Exiting.") + sys.exit(1) + + print(f"User dialogs:\n{dialogs}") + print("\n==================================\n") + + + # Set the seeds for reproducibility + torch.cuda.manual_seed(seed) + torch.manual_seed(seed) + model = load_model(model_name, quantization) + if peft_model: + model = load_peft_model(model, peft_model) + tokenizer = LlamaTokenizer.from_pretrained(model_name) + tokenizer.add_special_tokens( + { + "eos_token": "", + "bos_token": "", + "unk_token": "", + "pad_token": "[PAD]", + } + ) + + chats = format_tokens(dialogs, tokenizer) + + with torch.no_grad(): + for idx, chat in enumerate(chats): + safety_checker = get_safety_checker(enable_azure_content_safety, + enable_sensitive_topics, + enable_saleforce_content_safety, + ) + # Safety check of the user prompt + safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker] + are_safe = all([r[1] for r in safety_results]) + if are_safe: + print(f"User prompt deemed safe.") + print("User prompt:\n", dialogs[idx][0]["content"]) + print("\n==================================\n") + else: + print("User prompt deemed unsafe.") + for method, is_safe, report in safety_results: + if not is_safe: + print(method) + print(report) + print("Skipping the inferece as the prompt is not safe.") + sys.exit(1) # Exit the program with an error status + tokens= torch.tensor(chat).long() + tokens= tokens.unsqueeze(0) + tokens= tokens.to("cuda:0") + outputs = model.generate( + tokens, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_p=top_p, + temperature=temperature, + use_cache=use_cache, + top_k=top_k, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + **kwargs + ) + + output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Safety check of the model output + safety_results = [check(output_text) for check in safety_checker] + are_safe = all([r[1] for r in safety_results]) + if are_safe: + print("User input and model output deemed safe.") + print(f"Model output:\n{output_text}") + print("\n==================================\n") + + else: + print("Model output deemed unsafe.") + for method, is_safe, report in safety_results: + if not is_safe: + print(method) + print(report) + + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/inference/chat_utils.py b/inference/chat_utils.py new file mode 100644 index 000000000..c8c90582b --- /dev/null +++ b/inference/chat_utils.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
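
The utilities in `inference/chat_utils.py` below wrap each user/assistant turn in the Llama 2 chat template before tokenization. As a rough illustration of the string that `format_tokens` builds for a single turn, here is a hedged sketch assuming the standard `[INST]`/`[/INST]` and `<<SYS>>`/`<</SYS>>` markers:

```python
# Illustrative only: rendering one dialog turn with the Llama 2 chat markers.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def render_single_turn(system_prompt: str, user_message: str) -> str:
    # The system prompt is folded into the first user turn, as format_tokens does below.
    return f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{user_message.strip()} {E_INST}"

print(render_single_turn("Always answer with Haiku",
                         "I am going to Paris, what should I see?"))
```
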
+ +from typing import List, Literal, Optional, Tuple, TypedDict, Union +import json + +Role = Literal["user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" +DEFAULT_SYSTEM_PROMPT = """\ +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" + +def format_tokens(dialogs, tokenizer): + prompt_tokens = [] + for dialog in dialogs: + if dialog[0]["role"] != "system": + dialog = [ + { + "role": "system", + "content": DEFAULT_SYSTEM_PROMPT, + } + ] + dialog + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system','user' and 'assistant' roles, " + "starting with user and alternating (u/a/u/a/u...)" + ) + """ + Please verify that yout tokenizer support adding "[INST]", "[/INST]" to your inputs. + Here, we are adding it manually. + """ + dialog_tokens: List[int] = sum( + [ + tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + ) + for prompt, answer in zip(dialog[::2], dialog[1::2]) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + ) + prompt_tokens.append(dialog_tokens) + return prompt_tokens + + +def read_dialogs_from_file(file_path): + with open(file_path, 'r') as file: + dialogs = json.load(file) + return dialogs \ No newline at end of file diff --git a/inference/chats.json b/inference/chats.json new file mode 100644 index 000000000..4b1021bac --- /dev/null +++ b/inference/chats.json @@ -0,0 +1,22 @@ +[ + [{"role": "user", "content": "what is the recipe of mayonnaise?"}], + [ + {"role": "user", "content": "I am going to Paris, what should I see?"}, + { + "role": "assistant", + "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. 2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. 3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world." 
+ }, + {"role": "user", "content": "What is so great about #1?"} + ], + [ + {"role": "system", "content": "Always answer with Haiku"}, + {"role": "user", "content": "I am going to Paris, what should I see?"} + ], + [ + { + "role": "system", + "content": "Always answer with emojis" + }, + {"role": "user", "content": "How to go from Beijing to NY?"} + ] +] \ No newline at end of file diff --git a/inference/hf-text-generation-inference/README.md b/inference/hf-text-generation-inference/README.md new file mode 100644 index 000000000..64779d624 --- /dev/null +++ b/inference/hf-text-generation-inference/README.md @@ -0,0 +1,48 @@ +# Serving a fine tuned LLaMA model with HuggingFace text-generation-inference server + +This document shows how to serve a fine tuned LLaMA mode with HuggingFace's text-generation-inference server. This option is currently only available for models that were trained using the LoRA method or without using the `--use_peft` argument. + +## Step 0: Merging the weights (Only required if LoRA method was used) + +In case the model was fine tuned with LoRA mehtod we need to merge the weights of the base model with the adapter weight. For this we can use the script `merge_lora_weights.py` which is located in the same folder as this README file. + +The script takes the base model, the peft weight folder as well as an output as arguments: + +``` +python inference/hf-text-generation-inference/merge_lora_weights.py --base_model llama-7B --peft_model ft_output --output_dir data/merged_model_output +``` + +## Step 1: Serving the model +Subsequently, the model can be served using the docker container provided by [hf text-generation-inference](https://github.com/huggingface/text-generation-inference) started from the main directory of this repository: + +```bash +model=/data/merged_model_output +num_shard=2 +volume=$PWD/inference/hf-text-generation-inference/data +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard +``` + +The num_shard argument determines the number of GPU's the model should be sharded on. + +## Step 2: Running inference +After the loading of the model shards completed an inference can be executed by using one of the following commands: + +```bash +curl 127.0.0.1:8080/generate \ + -X POST \ + -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \ + -H 'Content-Type: application/json' +# OR for streaming inference +curl 127.0.0.1:8080/generate_stream \ + -X POST \ + -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \ + -H 'Content-Type: application/json' +``` + +Further information can be found in the documentation of the [hf text-generation-inference](https://github.com/huggingface/text-generation-inference) solution. + + + + + + diff --git a/inference/hf-text-generation-inference/merge_lora_weights.py b/inference/hf-text-generation-inference/merge_lora_weights.py new file mode 100644 index 000000000..b5137f1c6 --- /dev/null +++ b/inference/hf-text-generation-inference/merge_lora_weights.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
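
The `curl` calls in the text-generation-inference README above can also be issued programmatically. Below is a hedged sketch using the `requests` library (an assumption; any HTTP client works) against the same local endpoint; the `generated_text` field follows the TGI response format.

```python
import requests

# Assumes the TGI container from Step 1 is listening on localhost:8080.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 17}},
    headers={"Content-Type": "application/json"},
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```
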
+ +import fire +import torch +from peft import PeftModel +from transformers import LlamaForCausalLM, LlamaTokenizer + + +def main(base_model: str, + peft_model: str, + output_dir: str): + + model = LlamaForCausalLM.from_pretrained( + base_model, + load_in_8bit=False, + torch_dtype=torch.float16, + device_map="auto", + offload_folder="tmp", + ) + + tokenizer = LlamaTokenizer.from_pretrained( + base_model + ) + + model = PeftModel.from_pretrained( + model, + peft_model, + torch_dtype=torch.float16, + device_map="auto", + offload_folder="tmp", + ) + + model = model.merge_and_unload() + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/inference/inference.md b/inference/inference.md new file mode 100644 index 000000000..9c8586844 --- /dev/null +++ b/inference/inference.md @@ -0,0 +1,46 @@ +# Inference + +For inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments. +To finetune all model parameters the output dir of the training has to be given as --model_name argument. +In the case of a parameter efficient method like lora the base model has to be given as --model_name and the output dir of the training has to be given as --peft_model argument. +Additionally, a prompt for the model in the form of a text file has to be provided. The prompt file can either be piped through standard input or given as --prompt_file parameter. + +**Content Safety** +The inference script also supports safety checks for both user prompt and model outputs. In particular, we use two packages, [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/). + +**Note** +If using Azure content Safety, please make sure to get the endpoint and API key as described [here](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/) and add them as the following environment variables,`CONTENT_SAFETY_ENDPOINT` and `CONTENT_SAFETY_KEY`. + +Examples: + + ```bash +# Full finetuning of all parameters +cat | python inference/inference.py --model_name --use_auditnlg +# PEFT method +cat | python inference/inference.py --model_name --peft_model --use_auditnlg +# prompt as parameter +python inference/inference.py --model_name --prompt_file --use_auditnlg + ``` +The inference folder contains test prompts for summarization use-case: +``` +inference/samsum_prompt.txt +... +``` + +**Chat completion** +The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example: + +```bash +python chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chats.json --quantization --use_auditnlg + +``` + +## Other Inference Options + +Alternate inference options include: + +[**vLLM**](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html): +To use vLLM you will need to install it using the instructions [here](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#installation). +Once installed, you can use the vLLM_ineference.py script provided [here](vLLM_inference.py). + +[**TGI**](https://github.com/huggingface/text-generation-inference): Text Generation Inference (TGI) is another inference option available to you. 
For more information on how to set up and use TGI see [here](https://github.com/huggingface/text-generation-inference). diff --git a/inference/inference.py b/inference/inference.py new file mode 100644 index 000000000..b6398d407 --- /dev/null +++ b/inference/inference.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# from accelerate import init_empty_weights, load_checkpoint_and_dispatch + +import fire +import torch +import os +import sys +from typing import List + +from transformers import LlamaTokenizer +from safety_utils import get_safety_checker +from model_utils import load_model, load_peft_model + + +def main( + model_name, + peft_model: str=None, + quantization: bool=False, + max_new_tokens =100, #The maximum numbers of tokens to generate + prompt_file: str=None, + seed: int=42, #seed value for reproducibility + do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise. + min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens + use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: float=1.0, # [optional] The value used to modulate the next token probabilities. + top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering. + repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty. + length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. + enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api + enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs + enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 + **kwargs +): + if prompt_file is not None: + assert os.path.exists( + prompt_file + ), f"Provided Prompt file does not exist {prompt_file}" + with open(prompt_file, "r") as f: + user_prompt = "\n".join(f.readlines()) + elif not sys.stdin.isatty(): + user_prompt = "\n".join(sys.stdin.readlines()) + else: + print("No user prompt provided. 
Exiting.") + sys.exit(1) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(seed) + torch.manual_seed(seed) + model = load_model(model_name, quantization) + + tokenizer = LlamaTokenizer.from_pretrained(model_name) + tokenizer.add_special_tokens( + { + "eos_token": "", + "bos_token": "", + "unk_token": "", + "pad_token": "[PAD]", + } + ) + + safety_checker = get_safety_checker(enable_azure_content_safety, + enable_sensitive_topics, + enable_saleforce_content_safety, + ) + + # Safety check of the user prompt + safety_results = [check(user_prompt) for check in safety_checker] + are_safe = all([r[1] for r in safety_results]) + if are_safe: + print("User prompt deemed safe.") + print(f"User prompt:\n{user_prompt}") + else: + print("User prompt deemed unsafe.") + for method, is_safe, report in safety_results: + if not is_safe: + print(method) + print(report) + print("Skipping the inferece as the prompt is not safe.") + sys.exit(1) # Exit the program with an error status + + if peft_model: + model = load_peft_model(model, peft_model) + + model.eval() + + batch = tokenizer(user_prompt, return_tensors="pt") + batch = {k: v.to("cuda") for k, v in batch.items()} + + with torch.no_grad(): + outputs = model.generate( + **batch, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_p=top_p, + temperature=temperature, + min_length=min_length, + use_cache=use_cache, + top_k=top_k, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + **kwargs + ) + + output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Safety check of the model output + safety_results = [check(output_text) for check in safety_checker] + are_safe = all([r[1] for r in safety_results]) + if are_safe: + print("User input and model output deemed safe.") + print(f"Model output:\n{output_text}") + else: + print("Model output deemed unsafe.") + for method, is_safe, report in safety_results: + if not is_safe: + print(method) + print(report) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/inference/model_utils.py b/inference/model_utils.py new file mode 100644 index 000000000..815757fd4 --- /dev/null +++ b/inference/model_utils.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from peft import PeftModel +from transformers import LlamaForCausalLM + +# Function to load the main model for text generation +def load_model(model_name, quantization): + model = LlamaForCausalLM.from_pretrained( + model_name, + return_dict=True, + load_in_8bit=quantization, + device_map="auto", + low_cpu_mem_usage=True, + ) + return model + + +# Function to load the PeftModel for performance optimization +def load_peft_model(model, peft_model): + peft_model = PeftModel.from_pretrained(model, peft_model) + return peft_model \ No newline at end of file diff --git a/inference/safety_utils.py b/inference/safety_utils.py new file mode 100644 index 000000000..9c6d0c361 --- /dev/null +++ b/inference/safety_utils.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
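
Each checker constructed in `inference/safety_utils.py` below is a callable returning a `(method_name, is_safe, report)` tuple, which is how `inference.py` above aggregates the verdicts. A minimal sketch of that contract, assuming only the default Salesforce checker is enabled and using a placeholder prompt:

```python
from safety_utils import get_safety_checker

# Only the Salesforce safety Flan-T5 checker is enabled by default.
checkers = get_safety_checker(
    enable_azure_content_safety=False,
    enable_sensitive_topics=False,
    enable_saleforce_content_safety=True,
)

text = "Summarize this dialog: A: Hi Tom, are you busy tomorrow?"
results = [check(text) for check in checkers]
if all(is_safe for _, is_safe, _ in results):
    print("Text deemed safe.")
else:
    for method, is_safe, report in results:
        if not is_safe:
            print(method, report, sep="\n")
```
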
+ +import os +import torch +import warnings + +from peft import PeftConfig +from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM + +# Class for performing safety checks using AuditNLG library +class AuditNLGSensitiveTopics(object): + def __init__(self): + pass + + def __call__(self, output_text): + try: + from auditnlg.safety.exam import safety_scores + except ImportError as e: + print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt") + raise e + + + data = [{"output": output_text}] + + result = safety_scores(data=data, method="sensitive_topics") + scores = result[1]["all_scores"][0] + is_safe = scores["pred_class"] == "none" + report = "" + if not is_safe: + report += f"Predicted class: {scores['pred_class']}\n" + report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n" + report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n" + return "Sensitive Topics", is_safe, report + + +class SalesforceSafetyChecker(object): + def __init__(self): + pass + + def __call__(self, output_text): + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig + + config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base") + tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config) + model.eval() + + prefix = "Is the field safe or unsafe?" + input_ids = tokenizer(prefix + " " + output_text + " ", return_tensors="pt").input_ids + + if len(input_ids[0]) > 512: + warnings.warn( + "Input length is > 512 token. Safety check result could be incorrect." + ) + + with torch.no_grad(): + outputs = model.generate( + input_ids, + output_scores=True, + return_dict_in_generate=True, + max_new_tokens=20, + ) + + is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe" + + report = "" + if not is_safe: + true_false_ids = tokenizer("true false").input_ids[:2] + keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"] + scores = {} + for k, i in zip(keys, range(3,20,2)): + scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5) + + report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n" + report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n" + return "Salesforce Content Safety Flan T5 Base", is_safe, report + + + def get_total_length(self, data): + prefix = "Is the field safe or unsafe " + input_sample = " {output} ".format(**data[0]) + + return len(self.tokenizer(prefix + input_sample)["input_ids"]) + + +# Class for performing safety checks using Azure Content Safety service +class AzureSaftyChecker(object): + def __init__(self): + try: + from azure.ai.contentsafety import ContentSafetyClient + from azure.core.credentials import AzureKeyCredential + + key = os.environ["CONTENT_SAFETY_KEY"] + endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"] + except ImportError: + raise Exception( + "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety" + ) + except KeyError: + raise Exception( + "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT." 
+ ) + + self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key)) + + def __call__(self, output_text): + from azure.core.exceptions import HttpResponseError + from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory + + print(len(output_text)) + if len(output_text) > 1000: + raise Exception("Input length to safety check is too long (>1000).") + + categories = [ + TextCategory.VIOLENCE, + TextCategory.SELF_HARM, + TextCategory.SEXUAL, + TextCategory.HATE, + ] + + request = AnalyzeTextOptions(text=output_text, categories=categories) + + try: + response = self.client.analyze_text(request) + except HttpResponseError as e: + print("Analyze text failed.") + if e.error: + print(f"Error code: {e.error.code}") + print(f"Error message: {e.error.message}") + raise + print(e) + raise e + + levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"} + + severities = [ + getattr(response, c.name.lower() + "_result").severity for c in categories + ] + + DEFAULT_LEVELS = [0, 0, 0, 0] + + is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)]) + + report = "" + if not is_safe: + report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n" + report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n" + + return "Azure Content Saftey API", is_safe, report + + +# Function to load the PeftModel for performance optimization +# Function to determine which safety checker to use based on the options selected +def get_safety_checker(enable_azure_content_safety, + enable_sensitive_topics, + enable_saleforce_content_safety, + ): + safety_checker = [] + if enable_azure_content_safety: + safety_checker.append(AzureSaftyChecker()) + if enable_sensitive_topics: + safety_checker.append(AuditNLGSensitiveTopics()) + if enable_saleforce_content_safety: + safety_checker.append(SalesforceSafetyChecker()) + return safety_checker + + + + + diff --git a/inference/samsum_prompt.txt b/inference/samsum_prompt.txt new file mode 100644 index 000000000..a1c0a40d6 --- /dev/null +++ b/inference/samsum_prompt.txt @@ -0,0 +1,20 @@ +Summarize this dialog: +A: Hi Tom, are you busy tomorrow’s afternoon? +B: I’m pretty sure I am. What’s up? +A: Can you go with me to the animal shelter?. +B: What do you want to do? +A: I want to get a puppy for my son. +B: That will make him so happy. +A: Yeah, we’ve discussed it many times. I think he’s ready now. +B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) +A: I'll get him one of those little dogs. +B: One that won't grow up too big;-) +A: And eat too much;-)) +B: Do you know which one he would like? +A: Oh, yes, I took him there last Monday. He showed me one that he really liked. +B: I bet you had to drag him away. +A: He wanted to take it home right away ;-). +B: I wonder what he'll name it. +A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-))) +--- +Summary: \ No newline at end of file diff --git a/inference/vLLM_inference.py b/inference/vLLM_inference.py new file mode 100644 index 000000000..63c644148 --- /dev/null +++ b/inference/vLLM_inference.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
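
The script below drives vLLM interactively, one prompt at a time. vLLM also accepts a list of prompts for offline batched generation; a hedged sketch of that variant follows (the model path is a placeholder):

```python
from vllm import LLM, SamplingParams

# Assumption: a local or Hugging Face hub model path.
llm = LLM("PATH/to/LLAMA 2/7B")
sampling_params = SamplingParams(top_p=0.9, temperature=0.8, max_tokens=100)

prompts = [
    "What is Deep Learning?",
    "Summarize this dialog:\nA: Hi Tom, are you busy tomorrow?\n---\nSummary:",
]
for output in llm.generate(prompts, sampling_params=sampling_params):
    print(f"{output.prompt!r} -> {output.outputs[0].text!r}")
```
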
+ +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +import fire +import torch +import os +import sys +from peft import PeftModel, PeftConfig +from transformers import ( + LlamaConfig, + LlamaTokenizer, + LlamaForCausalLM +) +from vllm import LLM +from vllm import LLM, SamplingParams + +torch.cuda.manual_seed(42) +torch.manual_seed(42) + +def load_model(model_name, tp_size=1): + + llm = LLM(model_name, tensor_parallel_size=tp_size) + return llm + +def main( + model, + max_new_tokens=100, + user_prompt=None, + top_p=0.9, + temperature=0.8 +): + while True: + if user_prompt is None: + user_prompt = input("Enter your prompt: ") + + print(f"User prompt:\n{user_prompt}") + + print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request") + sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens) + + + outputs = model.generate(user_prompt, sampling_params=sampling_param) + + print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}") + user_prompt = input("Enter next prompt (press Enter to exit): ") + if not user_prompt: + break + +def run_script( + model_name: str, + peft_model=None, + tp_size=1, + max_new_tokens=100, + user_prompt=None, + top_p=0.9, + temperature=0.8 +): + model = load_model(model_name, tp_size) + main(model, max_new_tokens, user_prompt, top_p, temperature) + +if __name__ == "__main__": + fire.Fire(run_script) diff --git a/llama_finetuning.py b/llama_finetuning.py new file mode 100644 index 000000000..85bf18e3b --- /dev/null +++ b/llama_finetuning.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +import sys +from typing import List, Union + +import fire +import torch +import transformers +from datasets import load_dataset +import os.path as osp +from tqdm import tqdm + +# Unused imports removed +from utils import fsdp_auto_wrap_policy +from transformers import ( + LlamaForCausalLM, + LlamaTokenizer, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoTokenizer, + default_data_collator, + BitsAndBytesConfig +) +import torch.distributed as dist + +# Unused imports removed +from utils.train_utils import ( + set_tokenizer_params, + train, + evaluation, + freeze_transformer_layers, + check_frozen_layers_peft_model, + setup, + setup_environ_flags, + cleanup, + clear_gpu_cache, + get_parameter_dtypes, + print_model_size, + get_policies +) + +from utils.dataset_utils import get_preprocessed_dataset + +from utils.config_utils import ( + update_config, + generate_peft_config, + generate_dataset_config, +) +from peft import get_peft_model, TaskType, prepare_model_for_int8_training +import configs +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + MixedPrecision, + StateDictType, +) +from torch.utils.data import DistributedSampler +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened +import policies +from policies import AnyPrecisionAdamW +from configs import fsdp_config, train_config +import torch.optim as optim +from torch.optim.lr_scheduler import StepLR +from pkg_resources import packaging +import torch +import torch.cuda.nccl as nccl +import torch.distributed as dist +from transformers.models.t5.modeling_t5 import T5Block +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + + +def main(**kwargs): + # Update the configuration for the training and sharding process + 
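    # update_config copies any matching keyword arguments (for example
    # --batch_size_training 8 or --model_name <path>) onto the train_config and
    # fsdp_config dataclasses, so command-line flags override the defaults in configs/.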
update_config((train_config, fsdp_config), **kwargs) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(train_config.seed) + torch.manual_seed(train_config.seed) + + if train_config.enable_fsdp: + setup() + # torchrun specific + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + if torch.distributed.is_initialized(): + torch.cuda.set_device(rank) + setup_environ_flags(rank) + + # Calculate gradient accumulation steps + gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size + + # Load the pre-trained model and setup its configuration + model = LlamaForCausalLM.from_pretrained( + train_config.model_name, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + ) + + print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) + + # Prepare the model for int8 training if quantization is enabled + if train_config.quantization: + model = prepare_model_for_int8_training(model) + + # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled + if train_config.enable_fsdp and fsdp_config.pure_bf16: + model.to(torch.bfloat16) + + # Load the tokenizer and add special tokens + tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name) + tokenizer.add_special_tokens( + { + "eos_token": "", + "bos_token": "", + "unk_token": "", + "pad_token": '[PAD]', + } + ) + if train_config.use_peft: + peft_config = generate_peft_config(train_config, kwargs) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + #setting up FSDP if enable_fsdp is enabled + if train_config.enable_fsdp: + if not train_config.use_peft and train_config.freeze_layers: + + freeze_transformer_layers(train_config.num_freeze_layers) + + mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) + my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) + + model = FSDP( + model, + auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, + mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, + sharding_strategy=fsdp_config.sharding_strategy, + device_id=torch.cuda.current_device(), + limit_all_gathers=False, + ) + if fsdp_config.fsdp_activation_checkpointing: + policies.apply_fsdp_checkpointing(model) + elif not train_config.quantization and not train_config.enable_fsdp: + model.to("cuda") + + dataset_config = generate_dataset_config(train_config, kwargs) + + # Load and preprocess the dataset for training and validation + dataset_train = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="train", + ) + + if not train_config.enable_fsdp or rank == 0: + print(f"--> Training Set Length = {len(dataset_train)}") + + dataset_val = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="test", + ) + if not train_config.enable_fsdp or rank == 0: + print(f"--> Validation Set Length = {len(dataset_val)}") + + train_sampler = None + val_sampler = None + if train_config.enable_fsdp: + train_sampler = DistributedSampler( + dataset_train, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=True, + ) + if train_config.run_validation: + val_sampler = DistributedSampler( + dataset_val, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + + # Create DataLoaders for the training and validation dataset + train_dataloader = torch.utils.data.DataLoader( + 
dataset_train, + batch_size=train_config.batch_size_training, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + sampler=train_sampler if train_sampler else None, + drop_last=True, + collate_fn=default_data_collator, + ) + + if train_config.run_validation: + eval_dataloader = torch.utils.data.DataLoader( + dataset_val, + batch_size=train_config.val_batch_size, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + sampler=val_sampler if val_sampler else None, + drop_last=True, + collate_fn=default_data_collator, + ) + + # Initialize the optimizer and learning rate scheduler + if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": + optimizer = AnyPrecisionAdamW( + model.parameters(), + lr=train_config.lr, + momentum_dtype=torch.bfloat16, + variance_dtype=torch.bfloat16, + use_kahan_summation=False, + ) + else: + optimizer = optim.AdamW( + model.parameters(), + lr=train_config.lr, + weight_decay=0.0, + ) + scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) + + # Start the training process + results = train( + model, + train_dataloader, + eval_dataloader, + tokenizer, + optimizer, + scheduler, + gradient_accumulation_steps, + train_config, + fsdp_config if train_config.enable_fsdp else None, + local_rank if train_config.enable_fsdp else None, + rank if train_config.enable_fsdp else None, + ) + if not train_config.enable_fsdp or rank==0: + [print(f'Key: {k}, Value: {v}') for k, v in results.items()] + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/model_checkpointing/__init__.py b/model_checkpointing/__init__.py new file mode 100644 index 000000000..d9946f413 --- /dev/null +++ b/model_checkpointing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .checkpoint_handler import ( + load_model_checkpoint, + save_model_checkpoint, + save_distributed_model_checkpoint, + load_distributed_model_checkpoint, + load_optimizer_checkpoint, + save_optimizer_checkpoint, + save_model_and_optimizer_sharded, + load_model_sharded, +) diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py new file mode 100644 index 000000000..e917c7f2d --- /dev/null +++ b/model_checkpointing/checkpoint_handler.py @@ -0,0 +1,306 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from pathlib import Path +from datetime import datetime +import torch +import time + +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + FullStateDictConfig, # general model non-sharded, non-flattened params + LocalStateDictConfig, # flattened params, usable only by FSDP + # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. 
+) + +from torch.distributed._shard.checkpoint import ( + FileSystemReader, + FileSystemWriter, + save_state_dict, + load_state_dict, +) +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) + + +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +import torch.distributed._shard.checkpoint as dist_cp +import torch.distributed as dist + + +def get_date_of_run(): + """create date and time for file save uniqueness + example: 2022-05-07-08:31:12_PM' + """ + date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") + print(f"--> current date and time of run = {date_of_run}") + return date_of_run + + +# create singleton saving policies to avoid making over and over +fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + +def load_model_sharded(model, rank, cfg, verbose=True): + # torch.manual_seed(103) + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + load_dir = Path.cwd() / folder_name + + if not load_dir.exists(): + if rank == 0: + print(f"No sharded_state_dict checkpoint directory found...skipping") + return + if rank == 0: + print(f"loading model from model path: {load_dir} ") + reader = FileSystemReader(load_dir) + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + checkpoint = model.state_dict() + if rank == 0: + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + + dist_cp.load_state_dict( + state_dict=checkpoint, + storage_reader=reader, + ) + if rank == 0: + print(f"checkpoint after load_state_dict()") + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + model.load_state_dict(checkpoint) + if rank == 0: + print(f"Sharded state checkpoint loaded from {load_dir}") + + +def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True): + """save model and optimizer via sharded_state_dict to save_dir""" + + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + save_dir = Path.cwd() / folder_name + if rank == 0: + print(f"Saving model to {save_dir}") + + distributed_writer = dist_cp.FileSystemWriter( + save_dir, + ) + t0 = time.perf_counter() + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + + state_dict = {"model": model.state_dict()} + if optim is not None: + state_dict["optim"] = FSDP.optim_state_dict(model, optim) + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=distributed_writer, + planner=DefaultSavePlanner(), + + ) + dist.barrier() + t1 = time.perf_counter() + if rank == 0: + print(f"Sharded state checkpoint saved to {save_dir}") + print( + f"Checkpoint Time = {t1-t0:.4f}\n" + ) +def save_model_checkpoint( + model, + optimizer, + rank, + cfg, + epoch=1, +): + """saving model via rank0 cpu streaming and full_state_dict""" + + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, fullstate_save_policy + ): + cpu_state = model.state_dict() + + print(f"saving process: rank {rank} done w model state_dict\n") + + + if rank == 0: + print(f"--> saving model ...") + # create save path + save_dir = Path.cwd() / cfg.checkpoint_folder + save_dir.mkdir(parents=True, exist_ok=True) + save_name = cfg.model_name + "-" + str(epoch) + ".pt" + save_full_path = str(save_dir) + "/" + save_name + + # save model + torch.save(cpu_state, save_full_path) + + if cfg.verbose: + print(f"model 
checkpoint saved for epoch {epoch} at {save_full_path}\n") + + + +def load_model_checkpoint(model, rank, cfg, verbose=True): + """load local checkpoint to rank0 cpu + must be called * before * passing to FSDP""" + + if rank != 0: + return + + # where is the checkpoint at... + full_state_dict_model_path = ( + Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename + ) + # is it present... + if not full_state_dict_model_path.is_file(): + print( + f"model checkpoint {full_state_dict_model_path} not present. Returning..." + ) + return + + + model_checkpoint = torch.load(full_state_dict_model_path) + # integrate into loaded model + model.load_state_dict(model_checkpoint) + + if cfg.verbose: + print(f"model checkpoint loaded to rank0 cpu") + + +def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): + """save optimizer state via full state dict""" + + + print(f"--> optim state call on rank {rank}\n") + + # pull all sharded optimizer states to rank0 cpu... + + optim_state = FSDP.full_optim_state_dict(model, optimizer) + + if cfg.verbose: + print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") + + if rank == 0: + save_dir = Path.cwd() / cfg.checkpoint_folder + save_dir.mkdir(parents=True, exist_ok=True) + + opt_save_name = ( + cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt" + ) + opt_save_full_path = save_dir / opt_save_name + + print(f"--> saving optimizer state...") + + torch.save(optim_state, opt_save_full_path) + + print(f"--> saved {opt_save_full_path} to disk") + + +def load_optimizer_checkpoint(model, optimizer, rank, cfg): + """load an fdsp optimizer full_state checkpoint using scatter method + this ensures only rank 0 loads the optimizer state dict and scatters to other ranks + """ + + opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file + + if not opt_file_path.is_file(): + print( + f"warning - optimizer checkpoint not present {opt_file_path}. Returning. 
" + ) + return + + full_osd = None + + if rank == 0: + full_osd = torch.load(opt_file_path) + + if cfg.verbose: + print(f"loaded full osd on rank 0") + + # called from all ranks, though only rank0 has a valid param for full_osd + sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) + + if cfg.verbose: + print(f"optimizer shard loaded on rank {rank}") + + + +def load_distributed_model_checkpoint(model, rank, cfg): + if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT: + print(f"loading distributed checkpoint, rank {rank}...") + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + checkdir = Path.cwd() / folder_name + + if not checkdir.exists(): + if rank == 0: + print(f"No checkpoint directory found...skipping") + return + + + reader = FileSystemReader(checkdir) + + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + ): + state_dict = model.state_dict() + load_state_dict(state_dict, reader) + model.load_state_dict(state_dict) + + print(f"--> local state loaded on rank {rank}") + + return + + +def save_distributed_model_checkpoint(model, rank, cfg, epoch=1): + # distributed checkpoint saving + + # confirm type of checkpoint and save + if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT: + # create writer to current path + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + save_dir = Path.cwd() / folder_name + + writer = FileSystemWriter( + save_dir, + ) + + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + ): + state_dict = model.state_dict() + + + # write out distributed checkpoint + save_state_dict(state_dict, writer) + + return diff --git a/multi_node.slurm b/multi_node.slurm new file mode 100644 index 000000000..67580feac --- /dev/null +++ b/multi_node.slurm @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + + +#!/bin/bash + +#SBATCH --job-name=Nano-2d-trainer-20b-8nodes + +#SBATCH --ntasks=2 +#SBATCH --nodes=2 +#SBATCH --gpus-per-task=4 +#SBATCH --partition=train +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +# Enable for A100 +export FI_PROVIDER="efa" + +echo Node IP: $head_node_ip +export LOGLEVEL=INFO +# debugging flags (optional) +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS=WARN +export PYTHONFAULTHANDLER=1 +export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH +export CUDA_LAUNCH_BLOCKING=0 + +# on your cluster you might need these: +# set the network interface +export NCCL_SOCKET_IFNAME="ens" +export FI_EFA_USE_DEVICE_RDMA=1 + +srun torchrun --nproc_per_node 4 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:29500 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora + diff --git a/policies/__init__.py b/policies/__init__.py new file mode 100644 index 000000000..124477031 --- /dev/null +++ b/policies/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +from .mixed_precision import * +from .wrapping import * +from .activation_checkpointing_functions import apply_fsdp_checkpointing +from .anyprecision_optimizer import AnyPrecisionAdamW diff --git a/policies/activation_checkpointing_functions.py b/policies/activation_checkpointing_functions.py new file mode 100644 index 000000000..0a1e31f42 --- /dev/null +++ b/policies/activation_checkpointing_functions.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch +import os +import torch.distributed as dist +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) + +from transformers.models.t5.modeling_t5 import T5Block +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from functools import partial + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + +check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) + + +def apply_fsdp_checkpointing(model): + """apply activation checkpointing to model + returns None as model is updated directly + """ + print(f"--> applying fdsp activation checkpointing...") + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) diff --git a/policies/anyprecision_optimizer.py b/policies/anyprecision_optimizer.py new file mode 100644 index 000000000..22b0ca001 --- /dev/null +++ b/policies/anyprecision_optimizer.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# AnyPrecisionAdamW: a flexible precision AdamW optimizer +# with optional Kahan summation for high precision weight updates. +# Allows direct control over momentum, variance and auxiliary compensation +# buffer dtypes. +# Optional Kahan summation is used to offset precision reduction for +# the weight updates. This allows full training in BFloat16 (equal or +# better than FP32 results in many cases) due to high precision weight upates. 
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AnyPrecisionAdamW(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0.0,
+        use_kahan_summation=False,
+        momentum_dtype=torch.bfloat16,
+        variance_dtype=torch.bfloat16,
+        compensation_buffer_dtype=torch.bfloat16,
+    ):
+        """
+        Args:
+            params (iterable): iterable of parameters to optimize or dicts defining
+                parameter groups
+            lr (float, optional): learning rate (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients used for computing
+                running averages of gradient and its square (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve
+                numerical stability (default: 1e-8)
+            weight_decay (float, optional): weight decay coefficient (default: 0.0)
+
+            # Any Precision specific
+            use_kahan_summation = creates an auxiliary buffer to ensure high precision
+                model param updates (default: False)
+            momentum_dtype = dtype for momentum (default: BFloat16)
+            variance_dtype = dtype for uncentered variance (default: BFloat16)
+            compensation_buffer_dtype = dtype for the Kahan summation
+                buffer (default: BFloat16)
+
+            # Usage
+            This optimizer keeps its optimizer states, and the optional Kahan summation
+            buffer for high precision updates, in user controlled dtypes.
+            The defaults keep momentum, variance and the compensation buffer in BF16.
+            This can be run in FSDP mixed precision, amp, or full precision,
+            depending on what training pipeline you wish to work with.
+
+            Setting use_kahan_summation = False and changing the momentum and
+            variance dtypes to FP32 reverts this to a standard AdamW optimizer.
+
+        """
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            use_kahan_summation=use_kahan_summation,
+            momentum_dtype=momentum_dtype,
+            variance_dtype=variance_dtype,
+            compensation_buffer_dtype=compensation_buffer_dtype,
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+
+        if closure is not None:
+            with torch.enable_grad():
+                # to fix linter, we do not keep the returned loss for use atm.
+ closure() + + for group in self.param_groups: + + beta1, beta2 = group["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + use_kahan_summation = group["use_kahan_summation"] + + momentum_dtype = group["momentum_dtype"] + variance_dtype = group["variance_dtype"] + compensation_buffer_dtype = group["compensation_buffer_dtype"] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError( + "AnyPrecisionAdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + + state["step"] = torch.tensor(0.0) + + # momentum - EMA of gradient values + state["exp_avg"] = torch.zeros_like( + p, + dtype=momentum_dtype, + ) + + # variance uncentered - EMA of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, + dtype=variance_dtype, + ) + + # optional Kahan summation - accumulated error tracker + if use_kahan_summation: + state["compensation"] = torch.zeros_like( + p, + dtype=compensation_buffer_dtype, + ) + + # main processing ------------------------- + + # update the steps for each param group update + state["step"] += 1 + step = state["step"] + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + grad = p.grad + + # weight decay, AdamW style + if weight_decay: + p.data.mul_(1 - lr * weight_decay) + + # update momentum + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # update uncentered variance + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # adjust using bias1 + bias_correction1 = 1 - beta1**step + + step_size = lr / bias_correction1 + + # adjust using bias2 + denom_correction = (1 - beta2**step) ** 0.5 # avoids math import + + centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( + eps, alpha=1 + ) + + # lr update to compensation + if use_kahan_summation: + compensation = state["compensation"] + + compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) + + # update weights with compensation (Kahan summation) + # save error back to compensation for next iteration + temp_buffer = p.detach().clone() + p.data.add_(compensation) + compensation.add_(temp_buffer.sub_(p.data)) + + else: + # usual AdamW updates + p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) \ No newline at end of file diff --git a/policies/mixed_precision.py b/policies/mixed_precision.py new file mode 100644 index 000000000..410ee392e --- /dev/null +++ b/policies/mixed_precision.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch + +from torch.distributed.fsdp import ( + # FullyShardedDataParallel as FSDP, + # CPUOffload, + MixedPrecision, + # BackwardPrefetch, + # ShardingStrategy, +) + +# requires grad scaler in main loop +fpSixteen = MixedPrecision( + param_dtype=torch.float16, + # Gradient communication precision. + reduce_dtype=torch.float16, + # Buffer precision. + buffer_dtype=torch.float16, +) + +bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + # Gradient communication precision. + reduce_dtype=torch.bfloat16, + # Buffer precision. 
+ buffer_dtype=torch.bfloat16, + cast_forward_inputs=True, +) + +bfSixteen_mixed = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, +) + +fp32_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, +) diff --git a/policies/wrapping.py b/policies/wrapping.py new file mode 100644 index 000000000..d9fadc334 --- /dev/null +++ b/policies/wrapping.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch.distributed as dist +import torch.nn as nn +import torch + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, + CPUOffload, + BackwardPrefetch, + MixedPrecision, +) +from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + size_based_auto_wrap_policy, + enable_wrap, + wrap, +) + +import functools +from typing import Type + + +def get_size_policy(min_params=1e8): + num_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_params + ) + return num_wrap_policy + + +def get_llama_wrapper(): + """we register our main layer class and use the fsdp transformer wrapping policy + ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers + """ + # ==== use new transformer wrapper + + llama_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + return llama_auto_wrap_policy diff --git a/quickstart.ipynb b/quickstart.ipynb new file mode 100644 index 000000000..21ef027dc --- /dev/null +++ b/quickstart.ipynb @@ -0,0 +1,672 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Quick Start Notebook\n", + "\n", + "This notebook shows how to train a Llama 2 model on a single GPU (e.g. 
A10 with 24GB) using int8 quantization and LoRA.\n", + "\n", + "### Step 0: Install pre-requirements and convert checkpoint\n", + "\n", + "The example uses the Hugging Face trainer and model which means that the checkpoint has to be converted from its original format into the dedicated Hugging Face format.\n", + "The conversion can be achieved by running the `convert_llama_weights_to_hf.py` script provided with the transformer package.\n", + "Given that the original checkpoint resides under `models/7B` we can install all requirements and convert the checkpoint with:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# %%bash\n", + "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n", + "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n", + "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 1: Load the model\n", + "\n", + "Point model_id to model weight folder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "===================================BUG REPORT===================================\n", + "Welcome to bitsandbytes. For bug reports, please run\n", + "\n", + "python -m bitsandbytes\n", + "\n", + " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", + "================================================================================\n", + "bin /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so\n", + "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so\n", + "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n", + "CUDA SETUP: Detected CUDA version 112\n", + "CUDA SETUP: Loading binary /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /data/home/mreso/miniconda3/envs/llama did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n", + " warn(msg)\n", + "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/cuda/efa/lib')}\n", + " warn(msg)\n", + "The model weights are not tied. 
Please use the `tie_weights` method before using the `infer_auto_device` function.\n", + "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00, 5.09s/it]\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import LlamaForCausalLM, LlamaTokenizer\n", + "\n", + "model_id=\"./models_hf/7B\"\n", + "\n", + "tokenizer = LlamaTokenizer.from_pretrained(model_id)\n", + "\n", + "model =LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2: Load the preprocessed dataset\n", + "\n", + "We load and preprocess the samsum dataset which consists of curated pairs of dialogs and their summarization:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset samsum (/data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)\n", + "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-b14554a76c1c7ecd.arrow\n", + "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e40e61e15ebeb527.arrow\n", + "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e08ac9e1b792e7ba.arrow\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "import os\n", + "import sys\n", + "from utils.dataset_utils import get_preprocessed_dataset\n", + "from configs.datasets import samsum_dataset\n", + "\n", + "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: Check base model\n", + "\n", + "Run the base model on an example input:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Summarize this dialog:\n", + "A: Hi Tom, are you busy tomorrow’s afternoon?\n", + "B: I’m pretty sure I am. What’s up?\n", + "A: Can you go with me to the animal shelter?.\n", + "B: What do you want to do?\n", + "A: I want to get a puppy for my son.\n", + "B: That will make him so happy.\n", + "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", + "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", + "A: I'll get him one of those little dogs.\n", + "B: One that won't grow up too big;-)\n", + "A: And eat too much;-))\n", + "B: Do you know which one he would like?\n", + "A: Oh, yes, I took him there last Monday. 
He showed me one that he really liked.\n", + "B: I bet you had to drag him away.\n", + "A: He wanted to take it home right away ;-).\n", + "B: I wonder what he'll name it.\n", + "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", + "---\n", + "Summary:\n", + "A: Hi Tom, are you busy tomorrow’s afternoon?\n", + "B: I’m pretty sure I am. What’s up?\n", + "A: Can you go with me to the animal shelter?.\n", + "B: What do you want to do?\n", + "A: I want to get a puppy for my son.\n", + "B: That will make him so happy.\n", + "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", + "B\n" + ] + } + ], + "source": [ + "eval_prompt = \"\"\"\n", + "Summarize this dialog:\n", + "A: Hi Tom, are you busy tomorrow’s afternoon?\n", + "B: I’m pretty sure I am. What’s up?\n", + "A: Can you go with me to the animal shelter?.\n", + "B: What do you want to do?\n", + "A: I want to get a puppy for my son.\n", + "B: That will make him so happy.\n", + "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", + "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", + "A: I'll get him one of those little dogs.\n", + "B: One that won't grow up too big;-)\n", + "A: And eat too much;-))\n", + "B: Do you know which one he would like?\n", + "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n", + "B: I bet you had to drag him away.\n", + "A: He wanted to take it home right away ;-).\n", + "B: I wonder what he'll name it.\n", + "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", + "---\n", + "Summary:\n", + "\"\"\"\n", + "\n", + "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the base model only repeats the conversation.\n", + "\n", + "### Step 4: Prepare model for PEFT\n", + "\n", + "Let's prepare the model for Parameter Efficient Fine Tuning (PEFT):" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199\n" + ] + } + ], + "source": [ + "model.train()\n", + "\n", + "def create_peft_config(model):\n", + " from peft import (\n", + " get_peft_model,\n", + " LoraConfig,\n", + " TaskType,\n", + " prepare_model_for_int8_training,\n", + " )\n", + "\n", + " peft_config = LoraConfig(\n", + " task_type=TaskType.CAUSAL_LM,\n", + " inference_mode=False,\n", + " r=8,\n", + " lora_alpha=32,\n", + " lora_dropout=0.05,\n", + " target_modules = [\"q_proj\", \"v_proj\"]\n", + " )\n", + "\n", + " # prepare int-8 model for training\n", + " model = prepare_model_for_int8_training(model)\n", + " model = get_peft_model(model, peft_config)\n", + " model.print_trainable_parameters()\n", + " return model, peft_config\n", + "\n", + "# create peft config\n", + "model, lora_config = create_peft_config(model)\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Step 5: Define an optional profiler" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + 
"metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainerCallback\n", + "from contextlib import nullcontext\n", + "enable_profiler = False\n", + "output_dir = \"tmp/llama-output\"\n", + "\n", + "config = {\n", + " 'lora_config': lora_config,\n", + " 'learning_rate': 1e-4,\n", + " 'num_train_epochs': 1,\n", + " 'gradient_accumulation_steps': 2,\n", + " 'per_device_train_batch_size': 2,\n", + " 'gradient_checkpointing': False,\n", + "}\n", + "\n", + "# Set up profiler\n", + "if enable_profiler:\n", + " wait, warmup, active, repeat = 1, 1, 2, 1\n", + " total_steps = (wait + warmup + active) * (1 + repeat)\n", + " schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)\n", + " profiler = torch.profiler.profile(\n", + " schedule=schedule,\n", + " on_trace_ready=torch.profiler.tensorboard_trace_handler(f\"{output_dir}/logs/tensorboard\"),\n", + " record_shapes=True,\n", + " profile_memory=True,\n", + " with_stack=True)\n", + " \n", + " class ProfilerCallback(TrainerCallback):\n", + " def __init__(self, profiler):\n", + " self.profiler = profiler\n", + " \n", + " def on_step_end(self, *args, **kwargs):\n", + " self.profiler.step()\n", + "\n", + " profiler_callback = ProfilerCallback(profiler)\n", + "else:\n", + " profiler = nullcontext()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 6: Fine tune the model\n", + "\n", + "Here, we fine tune the model for a single epoch which takes a bit more than an hour on a A100." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n", + "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", + " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", + "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", + " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [389/389 1:12:06, Epoch 1/1]\n", + "
+    Step    Training Loss
+    10      1.965000
+    20      1.845600
+    30      1.801100
+    40      1.780900
+    50      1.715400
+    60      1.697800
+    70      1.707600
+    80      1.713300
+    90      1.663900
+    100     1.702700
+    110     1.658800
+    120     1.692400
+    130     1.644900
+    140     1.687900
+    150     1.686600
+    160     1.649600
+    170     1.666900
+    180     1.709200
+    190     1.670400
+    200     1.662700
+    210     1.681300
+    220     1.685500
+    230     1.663400
+    240     1.638300
+    250     1.627400
+    260     1.654300
+    270     1.640900
+    280     1.674700
+    290     1.657300
+    300     1.660200
+    310     1.666600
+    320     1.674500
+    330     1.656200
+    340     1.684300
+    350     1.667900
+    360     1.661400
+    370     1.676800
+    380     1.628100
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from transformers import default_data_collator, Trainer, TrainingArguments\n", + "\n", + "\n", + "\n", + "# Define training args\n", + "training_args = TrainingArguments(\n", + " output_dir=output_dir,\n", + " overwrite_output_dir=True,\n", + " bf16=True, # Use BF16 if available\n", + " # logging strategies\n", + " logging_dir=f\"{output_dir}/logs\",\n", + " logging_strategy=\"steps\",\n", + " logging_steps=10,\n", + " save_strategy=\"no\",\n", + " optim=\"adamw_torch_fused\",\n", + " max_steps=total_steps if enable_profiler else -1,\n", + " **{k:v for k,v in config.items() if k != 'lora_config'}\n", + ")\n", + "\n", + "with profiler:\n", + " # Create Trainer instance\n", + " trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " data_collator=default_data_collator,\n", + " callbacks=[profiler_callback] if enable_profiler else [],\n", + " )\n", + "\n", + " # Start training\n", + " trainer.train()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 7:\n", + "Save model checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_pretrained(output_dir)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 8:\n", + "Try the fine tuned model on the same example again to see the learning progress:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Summarize this dialog:\n", + "A: Hi Tom, are you busy tomorrow’s afternoon?\n", + "B: I’m pretty sure I am. What’s up?\n", + "A: Can you go with me to the animal shelter?.\n", + "B: What do you want to do?\n", + "A: I want to get a puppy for my son.\n", + "B: That will make him so happy.\n", + "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", + "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", + "A: I'll get him one of those little dogs.\n", + "B: One that won't grow up too big;-)\n", + "A: And eat too much;-))\n", + "B: Do you know which one he would like?\n", + "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n", + "B: I bet you had to drag him away.\n", + "A: He wanted to take it home right away ;-).\n", + "B: I wonder what he'll name it.\n", + "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", + "---\n", + "Summary:\n", + "A wants to get a puppy for his son. He took him to the animal shelter last Monday. He showed him one that he really liked. 
A will name it after his dead hamster - Lemmy.\n" + ] + } + ], + "source": [ + "model.eval()\n", + "with torch.no_grad():\n", + " print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "vscode": { + "interpreter": { + "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..cba786e07 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.0.1+cu118 +accelerate +appdirs +loralib +bitsandbytes==0.39.1 +black +black[jupyter] +datasets +fire +git+https://github.com/huggingface/peft.git +transformers +sentencepiece +py7zr +scipy + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 000000000..fc634e2b3 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .memory_utils import MemoryTrace +from .dataset_utils import * +from .fsdp_utils import fsdp_auto_wrap_policy +from .train_utils import * \ No newline at end of file diff --git a/utils/config_utils.py b/utils/config_utils.py new file mode 100644 index 000000000..1ba10419d --- /dev/null +++ b/utils/config_utils.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import inspect +from dataclasses import fields +from peft import ( + LoraConfig, + AdaptionPromptConfig, + PrefixTuningConfig, +) + +import configs.datasets as datasets +from configs import lora_config, llama_adapter_config, prefix_config, train_config +from .dataset_utils import DATASET_PREPROC + + +def update_config(config, **kwargs): + if isinstance(config, (tuple, list)): + for c in config: + update_config(c, **kwargs) + else: + for k, v in kwargs.items(): + if hasattr(config, k): + setattr(config, k, v) + elif "." 
in k: + # allow --some_config.some_param=True + config_name, param_name = k.split(".") + if type(config).__name__ == config_name: + if hasattr(config, param_name): + setattr(config, param_name, v) + else: + # In case of specialized config we can warm user + print(f"Warning: {config_name} does not accept parameter: {k}") + elif isinstance(config, train_config): + print(f"Warning: unknown parameter {k}") + + +def generate_peft_config(train_config, kwargs): + configs = (lora_config, llama_adapter_config, prefix_config) + peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) + names = tuple(c.__name__.rstrip("_config") for c in configs) + + assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" + + config = configs[names.index(train_config.peft_method)] + update_config(config, **kwargs) + params = {k.name: getattr(config, k.name) for k in fields(config)} + peft_config = peft_configs[names.index(train_config.peft_method)](**params) + + return peft_config + + +def generate_dataset_config(train_config, kwargs): + names = tuple(DATASET_PREPROC.keys()) + + assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" + + dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset] + update_config(dataset_config, **kwargs) + + return dataset_config \ No newline at end of file diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py new file mode 100644 index 000000000..9f2c0223d --- /dev/null +++ b/utils/dataset_utils.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch + +from functools import partial + +from ft_datasets import ( + get_grammar_dataset, + get_alpaca_dataset, + get_samsum_dataset, +) +from typing import Optional + + +DATASET_PREPROC = { + "alpaca_dataset": partial(get_alpaca_dataset, max_words=224), + "grammar_dataset": get_grammar_dataset, + "samsum_dataset": get_samsum_dataset, +} + + +def get_preprocessed_dataset( + tokenizer, dataset_config, split: str = "train" +) -> torch.utils.data.Dataset: + if not dataset_config.dataset in DATASET_PREPROC: + raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented") + + def get_split(): + return ( + dataset_config.train_split + if split == "train" + else dataset_config.test_split + ) + + return DATASET_PREPROC[dataset_config.dataset]( + dataset_config, + tokenizer, + get_split(), + ) diff --git a/utils/fsdp_utils.py b/utils/fsdp_utils.py new file mode 100644 index 000000000..e7ed13d2a --- /dev/null +++ b/utils/fsdp_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +def fsdp_auto_wrap_policy(model, transformer_layer_name): + import functools + import os + + from accelerate import FullyShardedDataParallelPlugin + from transformers.models.t5.modeling_t5 import T5Block + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + # FullyShardedDataParallelPlugin.get_module_class_from_name( + # model, transformer_layer_name + # ), + ), + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy \ No newline at end of file diff --git a/utils/memory_utils.py b/utils/memory_utils.py new file mode 100644 index 000000000..89d8c8f13 --- /dev/null +++ b/utils/memory_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import gc +import os +import sys +import threading + +import numpy as np +import psutil +import torch + +def byte2gb(x): + return int(x / 2**30) +# This context manager is used to track the peak memory usage of the process +class MemoryTrace: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = byte2gb(torch.cuda.memory_allocated()) + self.process = psutil.Process() + self.cpu_begin = byte2gb(self.cpu_mem_used()) + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = byte2gb(torch.cuda.memory_allocated()) + self.peak = byte2gb(torch.cuda.max_memory_allocated()) + cuda_info = torch.cuda.memory_stats() + self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0) + self.m_cuda_ooms = cuda_info.get("num_ooms", 0) + self.used = byte2gb(self.end - self.begin) + self.peaked = byte2gb(self.peak - self.begin) + self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") \ No newline at end of file diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 000000000..d1e61fb93 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, 
Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +import sys +from typing import List + +import fire +import torch +import transformers +from datasets import load_dataset +from tqdm import tqdm +""" +Unused imports: +import torch.nn as nn +import bitsandbytes as bnb +""" +from torch.nn import functional as F +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + prepare_model_for_int8_training, + set_peft_model_state_dict, +) +from transformers import LlamaForCausalLM, LlamaTokenizer +from torch.distributed.fsdp import StateDictType +import torch.distributed as dist +from pkg_resources import packaging +from .memory_utils import MemoryTrace +import model_checkpointing +import torch.cuda.nccl as nccl +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from pathlib import Path +sys.path.append(str(Path(__file__).resolve().parent.parent)) +from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper + +scaler = ShardedGradScaler() + + + + +def set_tokenizer_params(tokenizer: LlamaTokenizer): + tokenizer.pad_token_id = 0 + tokenizer.padding_side = "left" + +# Converting Bytes to Megabytes +def byte2mb(x): + return int(x / 2**20) + +def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None): + """ + Trains the model on the given dataloader + + Args: + model: The model to be trained + train_dataloader: The dataloader containing the training data + optimizer: The optimizer used for training + lr_scheduler: The learning rate scheduler + gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation + num_epochs: The number of epochs to train for + local_rank: The rank of the current node in a distributed setting + train_config: The training configuration + eval_dataloader: The dataloader containing the eval data + tokenizer: tokenizer used in the eval for decoding the predicitons + + Returns: results dictionary containing average training and validation perplexity and loss + """ + # Create a gradient scaler for fp16 + scaler = torch.cuda.amp.GradScaler() if train_config.use_fp16 else None + + train_prep = [] + train_loss = [] + val_prep = [] + val_loss =[] + results = {} + best_val_loss = float("inf") + for epoch in range(train_config.num_epochs): + with MemoryTrace() as memtrace: # track the memory usage + model.train() + total_loss = 0.0 + data_set_len = 0 + + for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")): + for key in batch.keys(): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + elif not train_config.quantization: + batch[key] = batch[key].to('cuda') + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + total_loss += loss.detach().float() + first_key = next(iter(batch)) + data_set_len += len(batch[first_key]) + if train_config.use_fp16: + # if fp16 is enabled, use gradient scaler to handle gradient update + scaler.scale(loss).backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + # regular backpropagation when fp16 is not used + loss.backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) 
- 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + print(f"\n step {step} is completed and loss is {loss.detach().float()}") + + # Reducing total_loss across all devices if there's more than one CUDA device + if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + train_epoch_loss = total_loss / data_set_len + train_perplexity = torch.exp(train_epoch_loss) + + train_prep.append(train_perplexity) + train_loss.append(train_epoch_loss) + + print(f"Max CUDA memory allocated was {memtrace.peak} GB") + print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") + print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") + + if train_config.run_validation: + eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer) + if train_config.save_model and eval_epoch_loss < best_val_loss: + + if train_config.use_peft: + + print(f"we are in the saving the PEFT modules") + model.save_pretrained(train_config.output_dir) + print(f"PEFT modules are saved in {train_config.output_dir} directory") + + else: + if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: + + model_checkpointing.save_model_checkpoint( + model, optimizer, rank, train_config, epoch=1 + ) + elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + print(" we are about to save the models *******") + + model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config) + if train_config.save_optimizer: + model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) + + if not train_config.use_peft and train_config.save_optimizer: + model_checkpointing.save_optimizer_checkpoint( + model, optimizer, rank, train_config, epoch=1 + ) + + + if local_rank == 0 and eval_epoch_loss < best_val_loss: + best_val_loss = eval_epoch_loss + print(f"best eval loss on epoch {epoch} is {best_val_loss}") + val_loss.append(best_val_loss) + val_prep.append(eval_ppl) + + print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}") + + avg_train_prep = sum(train_prep)/len(train_prep) + avg_train_loss = sum(train_loss)/len(train_loss) + if train_config.run_validation: + avg_eval_prep = sum(val_prep)/len(val_prep) + avg_eval_loss = sum(val_loss)/len(val_loss) + + results['avg_train_prep'] = avg_train_prep + results['avg_train_loss'] = avg_train_loss + if train_config.run_validation: + results['avg_eval_prep'] = avg_eval_prep + results['avg_eval_loss'] = avg_eval_loss + + + return results + +def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): + """ + Evaluates the model on the given dataloader + + Args: + model: The model to evaluate + eval_dataloader: The dataloader containing the evaluation data + local_rank: The rank of the current node in a distributed setting + tokenizer: The tokenizer used to decode predictions + + Returns: eval_ppl, eval_epoch_loss + """ + model.eval() + eval_preds = [] + eval_loss = 0.0 # Initialize evaluation loss + eval_dataset_len = 0 + with MemoryTrace() as memtrace: + for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")): + for key in batch.keys(): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + else: + batch[key] = batch[key].to('cuda') + 
# Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + # Forward pass and compute loss + outputs = model(**batch) + loss = outputs.loss + eval_loss += loss.detach().float() + first_key = next(iter(batch)) + eval_dataset_len+= len(batch[first_key]) + + # Decode predictions and add to evaluation predictions list + preds = torch.argmax(outputs.logits, -1) + eval_preds.extend( + tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) + ) + + # If there's more than one CUDA device, reduce evaluation loss across all devices + if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + + # Compute average loss and perplexity + eval_epoch_loss = eval_loss / eval_dataset_len + eval_ppl = torch.exp(eval_epoch_loss) + + # Print evaluation metrics + print(f" {eval_ppl=} {eval_epoch_loss=}") + return eval_ppl, eval_epoch_loss + +def freeze_transformer_layers(model, num_layer): + for i, layer in enumerate(model.model.layers): + if i < num_layer: + for param in layer.parameters(): + param.requires_grad = False + + +def check_frozen_layers_peft_model(model): + for i, layer in enumerate(model.base_model.model.model.layers): + for name, param in layer.named_parameters(): + print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") + + +def setup(): + """Initialize the process group for distributed training""" + dist.init_process_group("nccl") + + +def setup_environ_flags(rank): + """Set environment flags for debugging purposes""" + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + if rank == 0: + print(f"--> Running with torch dist debug set to detail") + + +def cleanup(): + """Clean up the process group after training""" + dist.destroy_process_group() + + +def clear_gpu_cache(rank=None): + """Clear the GPU cache for all ranks""" + if rank == 0: + print(f"Clearing GPU cache for all ranks") + torch.cuda.empty_cache() + + +def get_parameter_dtypes(model): + """Get the data types of model parameters""" + parameter_dtypes = {} + for name, parameter in model.named_parameters(): + parameter_dtypes[name] = parameter.dtype + return parameter_dtypes + +def print_model_size(model, config, rank: int = 0) -> None: + """ + Print model name, the number of trainable parameters and initialization time. + + Args: + model: The PyTorch model. + model_name (str): Name of the model. + init_time_start (float): Initialization start time. + init_time_end (float): Initialization end time. + rank (int, optional): Current process's rank. Defaults to 0. 
+ """ + if rank == 0: + print(f"--> Model {config.model_name}") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") + + + + +def get_policies(cfg, rank): + """Get the policies for mixed precision and fsdp wrapping""" + + verify_bfloat_support = ( + torch.version.cuda + and torch.cuda.is_bf16_supported() + and packaging.version.parse(torch.version.cuda).release >= (11, 0) + and dist.is_nccl_available() + and nccl.version() >= (2, 10) + ) + + + mixed_precision_policy = None + wrapping_policy = None + + # Mixed precision + if cfg.mixed_precision: + bf16_ready = verify_bfloat_support + + if bf16_ready and not cfg.use_fp16: + mixed_precision_policy = bfSixteen_mixed + if rank == 0: + print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + elif cfg.use_fp16: + mixed_precision_policy = fpSixteen + if rank == 0: + print(f"FP16 enabled") + else: + print(f"bFloat16 support not present. Using FP32, and not mixed precision") + wrapping_policy = get_llama_wrapper() + return mixed_precision_policy, wrapping_policy
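For orientation, the sketch below shows one way the utilities added in this diff are typically wired together before `train()` is called: `setup()` initializes the process group, `get_policies()` picks the FSDP mixed-precision and Llama auto-wrap policies, and the model is wrapped in FSDP. This is a minimal, hypothetical sketch, not the repo's actual entry point (`llama_finetuning.py` is referenced in `multi_node.slurm` but not shown here); the `train_config`/`fsdp_config` objects, dataloaders, optimizer, and scheduler are assumed to be constructed elsewhere.

```python
# Hypothetical wiring sketch -- not part of the diff. train_config, fsdp_config,
# the dataloaders, optimizer and lr_scheduler stand in for objects defined elsewhere;
# fsdp_config is assumed to carry the mixed_precision / use_fp16 flags get_policies reads.
import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.train_utils import setup, get_policies, clear_gpu_cache, train


def finetune(train_config, fsdp_config, train_dataloader, eval_dataloader,
             optimizer, lr_scheduler):
    setup()  # dist.init_process_group("nccl")
    rank = int(os.environ["RANK"])            # set by torchrun
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    clear_gpu_cache(local_rank)

    # Select the bf16/fp16 MixedPrecision policy and the LlamaDecoderLayer wrap policy
    mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)

    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    model = LlamaForCausalLM.from_pretrained(train_config.model_name)

    # Shard the model across ranks; each LlamaDecoderLayer becomes its own FSDP unit
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        mixed_precision=mixed_precision_policy,
        device_id=local_rank,
    )

    return train(
        model, train_dataloader, eval_dataloader, tokenizer, optimizer,
        lr_scheduler, train_config.gradient_accumulation_steps, train_config,
        fsdp_config, local_rank, rank,
    )
```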