diff --git a/.github/Pythia_saturation.png b/.github/Pythia_saturation.png new file mode 100644 index 0000000..f68e0ed Binary files /dev/null and b/.github/Pythia_saturation.png differ diff --git a/.github/TinyLlama_logo.png b/.github/TinyLlama_logo.png new file mode 100644 index 0000000..3f2c570 Binary files /dev/null and b/.github/TinyLlama_logo.png differ diff --git a/.github/llama2-training.png b/.github/llama2-training.png new file mode 100644 index 0000000..c4e4993 Binary files /dev/null and b/.github/llama2-training.png differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..59b43e4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__ +.idea +.DS_Store +*.egg-info +build +.venv +.vscode + +# data +data +checkpoints +out +wandb + +tests/original_falcon_40b.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..fe60df9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2023] Lightning AI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/PRETRAIN.md b/PRETRAIN.md new file mode 100644 index 0000000..46fe588 --- /dev/null +++ b/PRETRAIN.md @@ -0,0 +1,81 @@ +## Pretrain TinyLlama + +### Installation +We expect you have CUDA 11.8 installed. +#### Install Pytorch Nightly. +```bash +pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' +``` +#### Build XFormers from Source +Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1. You have to build it from source. +```bash +pip uninstall ninja -y && install ninja -U +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers +``` + + +#### Install Flash-Attention 2 and other fused operators: +```bash +git clone https://github.com/Dao-AILab/flash-attention +cd flash-attention +python setup.py install +cd csrc/rotary && pip install . +cd ../layer_norm && pip install . +cd ../xentropy && pip install . +cd ../.. && rm -rf flash-attention +``` +#### Install Remaining Dependencies +``` +pip install -r requirements.txt tokenizers sentencepiece +``` +to install other dependencies. +It may take >= 5 minutes to build xformers/flash-attention. Do not worry if the process seemly stagnant or the terminal print out many warnings. + +Then you are ready to go 🎉! + +### Data Preparation + +#### Download Datasets +Download the Slimpajama and Starcoderdata datasets to your chosen directory. +```bash +cd /path/to/dataset +git lfs install +git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B +git clone https://huggingface.co/datasets/bigcode/starcoderdata +``` +The SlimPajama dataset eats 893GB diskspace and the starcoderdata takes 290GB. + +#### Tokenize data +Use the provided scripts to tokenize the datasets and divide them into chunks. +```bash +python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0 +python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split validation --percentage 1.0 +python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0 +``` +The processed data will take 1.8T storage. + +### Pretraining +If your setup comprises two nodes, each with 8 GPUs, you can initiate pretraining with the following commands: + +On node 1: +``` +lightning run model \ + --node-rank=0 \ + --main-address=172.16.101.5 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star +``` +On node 2: +``` +lightning run model \ + --node-rank=1 \ + --main-address=172.16.101.5 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star +``` +You can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) if you have a slurm cluster. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ff540f4 --- /dev/null +++ b/README.md @@ -0,0 +1,159 @@ +
+ +# TinyLlama-1.1B +English | [中文](README_zh-CN.md) +
+ +The TinyLlama project aims to **pretrain** a **1.1B Llama model on 3 trillion tokens**. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01. + +
+ +
+ +We adopted exactly the same architecture and tokenizer as Llama 2. This means TinyLlama can be plugged and played in many open-source projects built upon Llama. Besides, TinyLlama is compact with only 1.1B parameters. This compactness allows it to cater to a multitude of applications demanding a restricted computation and memory footprint. + + +#### Releases Schedule +We will be rolling out intermediate checkpoints following the below schedule. We also include some baseline models for comparison. + +| Date | HF Checkpoint | Tokens | Step | HellaSwag Acc_norm | +|------------|-------------------------------------------------|--------|------|---------------------| +| Baseline | [StableLM-Alpha-3B](https://huggingface.co/stabilityai/stablelm-base-alpha-3b)| 800B | -- | 38.31 | +| Baseline | [Pythia-1B-intermediate-step-50k-105b](https://huggingface.co/EleutherAI/pythia-1b/tree/step50000) | 105B | 50k | 42.04 | +| Baseline | [Pythia-1B](https://huggingface.co/EleutherAI/pythia-1b) | 300B | 143k | 47.16 | +| 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) | 105B | 50k | 43.50 | +| 2023-09-16 | -- | 500B | -- | -- | +| 2023-10-01 | -- | 1T | -- | -- | +| 2023-10-16 | -- | 1.5T | -- | -- | +| 2023-10-31 | -- | 2T | -- | -- | +| 2023-11-15 | -- | 2.5T | -- | -- | +| 2023-12-01 | -- | 3T | -- | -- | + + + + + + +It can be observed that TinyLlama has so far progressed well 🎉🎉. + +Meanwhile, you can track the live cross entropy loss [here](https://wandb.ai/lance777/lightning_logs/reports/metric-train_loss-23-09-02-15-26-17---Vmlldzo1MjkzNzMw?accessToken=9843chbl7rfi1w03hxttpcnbo9z8t6088pw3ddn4h8teunaq0cy7j8hw9c5i02ve). + +## Potential Usecase +Tiny but strong language models are useful for many applications. Here are some potential usecases: +- Assisting speculative decoding of larger models. (See this [tutorial](https://twitter.com/karpathy/status/1697318534555336961) by Andrej Karpathy) +- Deployment on edge devices with restricted memory and computational capacities, for functionalities like real-time machine translation without an internet connection (the 4bit-quantized TinyLlama-1.1B's weight only takes up 550MB RAM). +- Enabling real-time dialogue generation in video games. + +Moreover, our code can be a valuable **reference for enthusiasts keen on pretraining language models under 5 billion parameters** without diving too early into [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). + +## Training Details +Below are some details of our training setup: + +| Setting | Description | +|---------------------------------|----------------------------------------------------------------| +| Parameters | 1.1B | +| Attention Variant | Grouped Query Attention | +| Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632| +| Sequence Length | 2048 | +| Batch Size | 2 million tokens (2048 * 1024) | +| Learning Rate | 4e-4 | +| Learning Rate Schedule | Cosine with 2000 warmup steps | +| Training Data | [Slimpajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) | +| Data Preprocessing | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata | +| Combined Dataset Size | 1 trillion tokens | +| Total Tokens During Training | 3 trillion (3 epochs/1430k steps) | +| Natural Language to Code Ratio | 7:3 | +| Hardware | 16 A100-40G GPUs | + + + + + + +## Blazingly Fast +Our codebase supports the following features: +- multi-gpu and multi-node distributed training with FSDP. +- flash attention 2. +- fused layernorm. +- fused swiglu. +- fused cross entropy loss . +- fused rotary positional embedding. + +Thanks to those optimizations, we achieve a throughput of **24k** tokens per second per A100-40G GPU, which translates to **56% model flops utilization** without activation checkpointing (We expect the MFU to be even higher on A100-80G). It means you can train a chinchilla-optimal TinyLlama (1.1B param, 22B tokens) in **32 hours with 8 A100**. Those optimizations also greatly reduce the memory footprint, allowing us to stuff our 1.1B model into 40GB GPU RAM and train with a per-gpu batch size of 16k tokens. **You can also pretrain TinyLlama on 3090/4090 GPUs with a smaller per-gpu batch size**. +Below is a comparison of the training speed of our codebase with that of Pythia and MPT. + + +| Model | A100 GPU hours taken on 300B tokens| +|-----------------------------------|------------------------------------| +|TinyLlama-1.1B | 3456 | +|[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b) | 4830 | +|[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) | 7920 | + + The Pythia number comes from their [paper](https://arxiv.org/abs/2304.01373). The MPT number comes from [here](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b), in which they say MPT-1.3B " was trained on 440 A100-40GBs for about half a day" on 200B tokens. + +The fact that TinyLlama is a relatively small model with grouped query attention means it is also fast during inference. Below are some throughputs that we measure: + +| Framework | Device | Batch Size | Throughput | +|-----------|--------------|-----|-----------| +|[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM | 1| 71.8 tokens/sec | +|[vLLM](https://github.com/vllm-project/vllm) | One A40 GPU | | | + + +## Getting Started +Please refer to [PRETRAIN.md](PRETRAIN.md) for instructions on how to pretrain TinyLlama. + +## TODO +This project is still under active development. We are a really small team. Community feedback and contributions are highly appreciated. Here are some things we plan to work on: + - [ ] Add scripts for pretraining on other datasets. + - [ ] Sequence length extrapolation. + - [ ] Test the throughput on RTX 3090/4090. + - [ ] Add fine-tuning scripts. + - [ ] Properly evaluate the model on downstream tasks. + - [ ] A demo running on mobile phones. + - [ ] Explore retrieval-augmentation. + + +## Acknowledgements +This repository is built upon [lit-gpt](https://github.com/Lightning-AI/lit-gpt) and [flash-attention](https://github.com/Dao-AILab/flash-attention). Be sure to explore this fantastic open-source project if it's new to you! +``` +@online{lit-gpt, + author = {Lightning AI}, + title = {Lit-GPT}, + url = {https://github.com/Lightning-AI/lit-gpt}, + year = {2023}, +} +@article{dao2023flashattention2, + title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author ={Dao, Tri}, + year ={2023} +} +``` + +## Citation +This project is currently contributed by [Peiyuan Zhang](https://github.com/jzhang38), [Guangtao Zeng](https://github.com/ChaosCodes), [Tianduo Wang](https://github.com/TianduoWang) and [Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/). + +If you find our work valuable, please cite: + +``` +@online{tinyllama, + author = {Peiyuan Zhang, Guangtao Zeng, Tianduo Wang, Wei Lu}, + title = {TinyLlama}, + url = {https://github.com/jzhang38/TinyLlama}, + year = {2023}, + month = {Oct}, +} +``` + +## Frequently Asked Questions + +#### 1. Why would pretraining a 1.1B model for so long make sense? Doesn't it contradict the Chinchilla Scaling Law? + +The training loss curve of Llama 2 + +Above is the training loss curve taken from the Llama 2 paper. Here I quote from that paper: "We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation". That is why we believe pretraining a 1.1B model for 3T tokens is a reasonable thing to do. Even if the loss curve does not go down eventually, we can still study the phenomenon of saturation and learn something from it. + +#### 2. What does "saturation" mean? +Figure 10 of the Pythia paper + +The figure from the Pythia paper displays the LAMBADA accuracy plotted against the total training tokens (300B). The term "saturation" pertains specifically to the 70M and 160M models. Notably, even the 410M model does not saturate with 300B tokens, as it continues to show an increasing trend, similar to the trend of larger models. diff --git a/README_zh-CN.md b/README_zh-CN.md new file mode 100644 index 0000000..bff263d --- /dev/null +++ b/README_zh-CN.md @@ -0,0 +1,149 @@ +
+ +# TinyLlama-1.1B +[English](README.md) | 中文 +
+ +TinyLlama项目旨在在30万亿tokens上进行预训练,构建一个拥有11亿参数的Llama模型。经过精心优化,我们"仅"需16块A100-40G的GPU,便可在90天内完成这个任务🚀🚀。训练已于2023-09-01开始。 + + +
+ +
+ +我们采用了与Llama 2完全相同的架构和分词器。这意味着TinyLlama可以在许多基于Llama的开源项目中即插即用。此外,TinyLlama只有1.1B的参数,体积小巧,适用于需要限制计算和内存占用的多种应用。 + + +#### 发布时间表 + +我们会根据以下计划逐步发布中间checkpoint。我们也列了一些基线模型进行比较。 + + + +| Date | HF Checkpoint | Tokens | Step | HellaSwag Acc_norm | +|------------|-------------------------------------------------|--------|------|---------------------| +| Baseline | [StableLM-Alpha-3B](https://huggingface.co/stabilityai/stablelm-base-alpha-3b)| 800B | -- | 38.31 | +| Baseline | [Pythia-1B-intermediate-step-50k-105b](https://huggingface.co/EleutherAI/pythia-1b/tree/step50000) | 105B | 50k | 42.04 | +| Baseline | [Pythia-1B](https://huggingface.co/EleutherAI/pythia-1b) | 300B | 143k | 47.16 | +| 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) | 105B | 50k | 43.50 | +| 2023-09-16 | -- | 500B | -- | -- | +| 2023-10-01 | -- | 1T | -- | -- | +| 2023-10-16 | -- | 1.5T | -- | -- | +| 2023-10-31 | -- | 2T | -- | -- | +| 2023-11-15 | -- | 2.5T | -- | -- | +| 2023-12-01 | -- | 3T | -- | -- | + + + + + +从上面可以看出,TinyLlama目前的进展非常好🎉🎉。 + + +你也可以在[这里](https://wandb.ai/lance777/lightning_logs/reports/metric-train_loss-23-09-02-15-26-17---Vmlldzo1MjkzNzMw?accessToken=9843chbl7rfi1w03hxttpcnbo9z8t6088pw3ddn4h8teunaq0cy7j8hw9c5i02ve)实时跟踪TinyLlama的训练损失。 + +## 潜在场景 +小型但强大的语言模型对许多应用都很有用。以下是一些潜在的场景: +- 帮助对大型模型进行speculative decoding。 +- 在边缘装置上运行,比如离线的实时机器翻译 (TinyLlama的4比特量化版本的模型权重只需要550MB的内存)。 +- 在游戏中实现实时对话生成(因为还得给游戏本身留显存所以模型要小)。 + +此外,我们的代码可以给初学者做一个**入门预训练的简洁参考**。如果你要训练50亿以下参数的语言模型, 你其实不需要Megatron-LM。 + +## 训练细节 +以下是我们训练设置的一些细节: + +| Setting | Description | +|---------------------------------|----------------------------------------------------------------| +| Parameters | 1.1B | +| Attention Variant | Grouped Query Attention | +| Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632| +| Sequence Length | 2048 | +| Batch Size | 2 million tokens (2048 * 1024) | +| Learning Rate | 4e-4 | +| Learning Rate Schedule | Cosine with 2000 warmup steps | +| Training Data | [Slimpajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) | +| Data Preprocessing | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata | +| Combined Dataset Size | 1 trillion tokens | +| Total Tokens During Training | 3 trillion (3 epochs/143k steps) | +| Natural Language to Code Ratio | 7:3 | +| Hardware | 16 A100-40G GPUs | + + + + + + +## 速度极快 +我们的代码库支持以下特性: +- multi-gpu and multi-node distributed training with FSDP. +- flash attention 2. +- fused layernorm. +- fused swiglu. +- fused cross entropy loss . +- fused rotary positional embedding. + +有了这些优化, 我们可以达到**24k tokens/秒/A100**的训练速度,也就是56%的MFU(在A100-80G上的MFU会更高)。这个速度可以让你可以在**8个A100上用32小时训练一个chinchilla-optimial的模型**(11亿参数,220亿token)。这些优化也大大减少了显存占用, 我们可以把11亿参数的模型塞入40GB的GPU里面还能同时维持16k tokens的per-gpu batch size。只需要把batch size改小一点, 你就可以在**RTX 3090/4090**上面训练TinyLlama。 +下面是我们的代码库与Pythia和MPT的训练速度的比较。 + + +| Model | A100 GPU hours taken on 300B tokens| +|-----------------------------------|------------------------------------| +|TinyLlama-1.1B | 3456 | +|[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b) | 4830 | +|[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) | 7920 | + + Pythia的数字来自他们的论文。MPT的数字来自[这里](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b),作者说MPT-1.3B"was trained on 440 A100-40GBs for about half a day" on 200B tokens。 + +TinyLlama是一个相对较小的模型, 同时我们用了GQA, 这意味着它在推理期间也很快。以下是我们测量的一些推理速度: + +| Framework | Device | Batch Size | Throughput | +|-----------|--------------|-----|-----------| +|[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM | 1| 71.8 tokens/sec | +|[vLLM](https://github.com/vllm-project/vllm) | One A40 GPU | | | + + +## 开始训练 +请参考[PRETRAIN.md](PRETRAIN.md)。 + +## TODO +该项目仍在积极开发中。我们团队很小,非常欢迎社区的反馈和贡献。以下是我们计划进行的一些工作: + - [ ] Add scripts for pretraining on other datasets. + - [ ] Sequence length extrapolation. + - [ ] Test the throughput on RTX 3090/4090. + - [ ] Add fine-tuning scripts. + - [ ] Properly evaluate the model on downstream tasks. + - [ ] A demo running on mobile phones. + - [ ] Explore retrieval-augmentation. + + +## Acknowledgements +这个仓库基于出色的开源项目[lit-gpt](https://github.com/Lightning-AI/lit-gpt)和[flash-attention](https://github.com/Dao-AILab/flash-attention)构建. +``` +@online{lit-gpt, + author = {Lightning AI}, + title = {Lit-GPT}, + url = {https://github.com/Lightning-AI/lit-gpt}, + year = {2023}, +} +@article{dao2023flashattention2, + title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author ={Dao, Tri}, + year ={2023} +} +``` + +## Citation +此项目目前由[Peiyuan Zhang](https://github.com/jzhang38),[Guangtao Zeng](https://github.com/ChaosCodes),[Tianduo Wang](https://github.com/TianduoWang)和[Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/)贡献。 + +如果您觉得我们的工作有价值, 可以引用: + +``` +@online{tinyllama, + author = {Peiyuan Zhang, Guangtao Zeng, Tianduo Wang, Wei Lu}, + title = {TinyLlama}, + url = {https://github.com/jzhang38/TinyLlama}, + year = {2023}, + month = {Oct}, +} +``` diff --git a/lit_gpt/__init__.py b/lit_gpt/__init__.py new file mode 100644 index 0000000..a15c7f4 --- /dev/null +++ b/lit_gpt/__init__.py @@ -0,0 +1,20 @@ +from lit_gpt.model import GPT +from lit_gpt.config import Config +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss +from lightning_utilities.core.imports import RequirementCache + +if not bool(RequirementCache("torch>=2.1.0dev")): + raise ImportError( + "Lit-GPT requires torch nightly (future torch 2.1). Please follow the installation instructions in the" + " repository README.md" + ) +_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0") +if not bool(_LIGHTNING_AVAILABLE): + raise ImportError( + "Lit-GPT requires Lightning nightly (future lightning 2.1). Please run:\n" + f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" + ) + + +__all__ = ["GPT", "Config", "Tokenizer"] diff --git a/lit_gpt/adapter.py b/lit_gpt/adapter.py new file mode 100644 index 0000000..d0c9ba8 --- /dev/null +++ b/lit_gpt/adapter.py @@ -0,0 +1,283 @@ +"""Implementation of the paper: + +LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention +https://arxiv.org/abs/2303.16199 + +Port for Lit-GPT +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from typing_extensions import Self + +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.model import KVCache, RoPECache, apply_rope + + +@dataclass +class Config(BaseConfig): + adapter_prompt_length: int = 10 + adapter_start_layer: int = 2 + + +class GPT(BaseModel): + """The implementation is identical to `lit_gpt.model.GPT` with the exception that + the `Block` saves the layer index and passes it down to the attention layer.""" + + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + self.adapter_kv_caches: List[KVCache] = [] + + def reset_cache(self) -> None: + super().reset_cache() + self.adapter_kv_caches.clear() + + def forward( + self, + idx: torch.Tensor, + max_seq_length: Optional[int] = None, + input_pos: Optional[torch.Tensor] = None, + lm_head_chunk_size: int = 0, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + B, T = idx.size() + use_kv_cache = input_pos is not None + + block_size = self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + if use_kv_cache: # not relevant otherwise + assert ( + max_seq_length >= T + ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx) + # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 + if use_kv_cache and self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) + + cos, sin = self.rope_cache + if use_kv_cache: + cos = cos.index_select(0, input_pos) + sin = sin.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + cos = cos[:T] + sin = sin[:T] + mask = None + + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if not use_kv_cache: + for block in self.transformer.h: + x, *_ = block(x, (cos, sin), max_seq_length) + else: + self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) + self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)] + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i], self.adapter_kv_caches[i] = block( + x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i] + ) + + x = self.transformer.ln_f(x) + + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, CausalSelfAttention): + module.reset_parameters() + + +class Block(nn.Module): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__() + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + adapter_kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: + n_1 = self.norm_1(x) + h, new_kv_cache, new_adapter_kv_cache = self.attn( + n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache + ) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = x + h + self.mlp(n_2) + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + x = x + h + x = x + self.mlp(self.norm_2(x)) + return x, new_kv_cache, new_adapter_kv_cache + + +class CausalSelfAttention(BaseCausalSelfAttention): + """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention + over the adaption prompt.""" + + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__(config) + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) + # gate for adaption + self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) + self.reset_parameters() + self.block_idx = block_idx + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + adapter_kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # repeat k and v if necessary + if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) + # for MHA this is a no-op + k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + n_elem = int(self.config.rotary_percentage * self.config.head_size) + + cos, sin = rope + q_roped = apply_rope(q[..., :n_elem], cos, sin) + k_roped = apply_rope(k[..., :n_elem], cos, sin) + q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) + k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) + + if kv_cache is not None: + cache_k, cache_v = kv_cache + cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) + # check if reached token limit + if input_pos[-1] >= max_seq_length: + input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) + # shift 1 position to the left + cache_k = torch.roll(cache_k, -1, dims=2) + cache_v = torch.roll(cache_v, -1, dims=2) + k = cache_k.index_copy_(2, input_pos, k) + v = cache_v.index_copy_(2, input_pos, v) + kv_cache = k, v + + y = self.scaled_dot_product_attention(q, k, v, mask=mask) + + if self.block_idx >= self.config.adapter_start_layer: + aT = self.config.adapter_prompt_length + if adapter_kv_cache is not None: + ak, av = adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, C) + aqkv = self.attn(prefix) + aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) + av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) + adapter_kv_cache = (ak, av) + + amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) + ay = self.scaled_dot_product_attention(q, ak, av, amask) + y = y + self.gating_factor * ay + + y = y.reshape(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.proj(y) + + return y, kv_cache, adapter_kv_cache + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.gating_factor) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with older checkpoints.""" + if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_as_trainable(model: GPT) -> None: + """Sets `requires_grad=False` for all non-adapter weights.""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) + + +def adapter_filter(key: str, value: Any) -> bool: + return "adapter_wte" in key or "gating_factor" in key diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py new file mode 100644 index 0000000..25b0fc4 --- /dev/null +++ b/lit_gpt/adapter_v2.py @@ -0,0 +1,290 @@ +"""Implementation of the paper: + +LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model +https://arxiv.org/abs/2304.15010 + +Port for Lit-GPT +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.nn as nn +from typing_extensions import Self + +import lit_gpt +from lit_gpt.adapter import GPT as BaseModel +from lit_gpt.adapter import Block as BaseBlock +from lit_gpt.adapter import Config as BaseConfig +from lit_gpt.adapter import KVCache, RoPECache +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.model import apply_rope +from lit_gpt.utils import map_old_state_dict_weights + + +@dataclass +class Config(BaseConfig): + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.adapter_v2, self._mlp_class) + + +def adapter_filter(key: str, value: Any) -> bool: + adapter_substrings = ( + # regular adapter v1 parameters + "adapter_wte", + "gating_factor", + # adapter v2: new bias and scale used in Linear + "adapter_scale", + "adapter_bias", + # adapter v2: Norm parameters are now trainable + "norm_1", + "norm_2", + "ln_f", + ) + return any(s in key for s in adapter_substrings) + + +class AdapterV2Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, **kwargs) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) + self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.adapter_scale * (self.linear(x) + self.adapter_bias) + + def reset_parameters(self) -> None: + nn.init.zeros_(self.adapter_bias) + nn.init.ones_(self.adapter_scale) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + self.adapter_kv_caches: List[KVCache] = [] + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, CausalSelfAttention): + module.reset_parameters() + if isinstance(module, AdapterV2Linear): + module.reset_parameters() + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = {"lm_head.weight": "lm_head.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + def __init__(self, config: Config, block_idx: int) -> None: + """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for + parameter-efficient fine-tuning. + + *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for + query, key and value for each head) we can do this in a single pass with a single weight matrix. + """ + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) + # output projection + self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) + # gate for adaption + self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) + self.reset_parameters() + self.block_idx = block_idx + + self.config = config + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + adapter_kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # repeat k and v if necessary + if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) + # for MHA this is a no-op + k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + n_elem = int(self.config.rotary_percentage * self.config.head_size) + + cos, sin = rope + q_roped = apply_rope(q[..., :n_elem], cos, sin) + k_roped = apply_rope(k[..., :n_elem], cos, sin) + q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) + k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) + + if kv_cache is not None: + cache_k, cache_v = kv_cache + cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) + # check if reached token limit + if input_pos[-1] >= max_seq_length: + input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) + # shift 1 position to the left + cache_k = torch.roll(cache_k, -1, dims=2) + cache_v = torch.roll(cache_v, -1, dims=2) + k = cache_k.index_copy_(2, input_pos, k) + v = cache_v.index_copy_(2, input_pos, v) + kv_cache = k, v + + y = self.scaled_dot_product_attention(q, k, v, mask=mask) + + if self.block_idx >= self.config.adapter_start_layer: + aT = self.config.adapter_prompt_length + if adapter_kv_cache is not None: + ak, av = adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, C) + aqkv = self.attn(prefix) + aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) + av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) + adapter_kv_cache = (ak, av) + + amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) + ay = self.scaled_dot_product_attention(q, ak, av, amask) + y = y + self.gating_factor * ay + + y = y.reshape(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.proj(y) + + return y, kv_cache, adapter_kv_cache + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.gating_factor) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + # For compatibility with older checkpoints + if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_v2_as_trainable(model: GPT) -> None: + """Sets requires_grad=False for all non-adapter weights""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) diff --git a/lit_gpt/config.py b/lit_gpt/config.py new file mode 100644 index 0000000..7c91c11 --- /dev/null +++ b/lit_gpt/config.py @@ -0,0 +1,665 @@ +from dataclasses import dataclass +from typing import Any, Literal, Optional, Type + +import torch +from typing_extensions import Self + +import lit_gpt.model +from lit_gpt.utils import find_multiple + + +@dataclass +class Config: + org: str = "Lightning-AI" + name: str = "lit-GPT" + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = True + # to use multi-head attention (MHA), set this to `n_head` (default) + # to use multi-query attention (MQA), set this to 1 + # to use grouped-query attention (GQA), set this to a value in between + # Example with `n_head=4` + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ │ │ │ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + # MHA GQA MQA + # n_query_groups=4 n_query_groups=2 n_query_groups=1 + # + # credit https://arxiv.org/pdf/2305.13245.pdf + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" + intermediate_size: Optional[int] = None + condense_ratio: int = 1 + + def __post_init__(self): + # error checking + assert self.n_embd % self.n_head == 0 + # vocab size should be a power of 2 to be optimal on hardware. compute the closest value + if self.padded_vocab_size is None: + self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple) + # compute the number of query groups + if self.n_query_groups is not None: + assert self.n_head % self.n_query_groups == 0 + else: + self.n_query_groups = self.n_head + # compute the intermediate size for MLP if not set + if self.intermediate_size is None: + if self._mlp_class == "LLaMAMLP": + raise ValueError("The config needs to set the `intermediate_size`") + self.intermediate_size = 4 * self.n_embd + + @property + def head_size(self) -> int: + return self.n_embd // self.n_head + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + conf_dict = name_to_config[name].copy() + conf_dict.update(kwargs) + return cls(**conf_dict) + + @property + def mlp_class(self) -> Type: + # `self._mlp_class` cannot be the type to keep the config json serializable + return getattr(lit_gpt.model, self._mlp_class) + + @property + def norm_class(self) -> Type: + # `self._norm_class` cannot be the type to keep the config json serializable + if self._norm_class == "RMSNorm": + from lit_gpt.rmsnorm import RMSNorm + + return RMSNorm + elif self._norm_class == "FusedRMSNorm": + from lit_gpt.rmsnorm import FusedRMSNorm + return FusedRMSNorm + return getattr(torch.nn, self._norm_class) + + +######################## +# Stability AI StableLM +######################## +configs = [ + # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json + dict(org="stabilityai", name="stablelm-base-alpha-3b", padding_multiple=512), + # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json + dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json + dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32, padding_multiple=512), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json + dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256), +] + +#################### +# EleutherAI Pythia +#################### +pythia = [ + # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json + dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128), + # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json + dict( + org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json + dict( + org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json + dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_layer=16, n_embd=2048, n_head=8, padding_multiple=128), + # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json + dict( + org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json + dict( + org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, n_head=32, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json + dict( + org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, n_embd=4096, n_head=32, padding_multiple=256 + ), + # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json + dict( + org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40, padding_multiple=512 + ), +] +configs.extend(pythia) +for c in pythia: + copy = c.copy() + copy["name"] = f"{c['name']}-deduped" + configs.append(copy) + + +#################################### +# togethercomputer RedPajama INCITE +#################################### +redpajama_incite = [ + # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json + dict( + org="togethercomputer", + name="RedPajama-INCITE-{}-3B-v1", + block_size=2048, + n_layer=32, + n_embd=2560, + n_head=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json + dict( + org="togethercomputer", + name="RedPajama-INCITE-7B-{}", + block_size=2048, + n_layer=32, + n_embd=4096, + n_head=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # this redirects to the checkpoint above. kept for those who had the old weights already downloaded + dict( + org="togethercomputer", + name="RedPajama-INCITE-{}-7B-v0.1", + block_size=2048, + n_layer=32, + n_embd=4096, + n_head=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), +] +for c in redpajama_incite: + for kind in ("Base", "Chat", "Instruct"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +################# +# TII UAE Falcon +################# +falcon = [ + # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json + dict( + org="tiiuae", + name="falcon-7b{}", + block_size=2048, + padded_vocab_size=65024, + n_layer=32, + n_head=71, + n_embd=4544, + rotary_percentage=1.0, + parallel_residual=True, + n_query_groups=1, + bias=False, + # this is not in the config, but in the original model implementation, only for this config + shared_attention_norm=True, + ), + # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json + dict( + org="tiiuae", + name="falcon-40b{}", + block_size=2048, + padded_vocab_size=65024, + n_layer=60, + n_head=128, + n_embd=8192, + rotary_percentage=1.0, + parallel_residual=True, + n_query_groups=8, + bias=False, + ), +] +for c in falcon: + for kind in ("", "-instruct"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +############################# +# StatNLP Research +############################# +tiny_LLaMA = [ + + # https://twitter.com/cwolferesearch/status/1691929174175264858 + dict( + org="StatNLP-research", + name="tiny_LLaMA_1b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ), + dict( + org="StatNLP-research", + name="tiny_LLaMA_120M", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=12, + n_head=12, + n_embd=768, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ), +] +configs.extend(tiny_LLaMA) + + +############################# +# OpenLM Research Open LLaMA +############################# +open_LLaMA = [ + # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_3b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=26, + n_head=32, + n_embd=3200, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=8640, + ), + # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_7b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_13b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), +] +configs.extend(open_LLaMA) + + +############### +# LMSYS Vicuna +############### +vicuna = [ + # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-7b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-13b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-33b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=60, + n_head=52, + n_embd=6656, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=17920, + ), + dict( + org="lmsys", + name="vicuna-7b-v1.5", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + dict( + org="lmsys", + name="vicuna-7b-v1.5-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + condense_ratio=4, + ), + dict( + org="lmsys", + name="vicuna-13b-v1.5", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + dict( + org="lmsys", + name="vicuna-13b-v1.5-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + condense_ratio=4, + ), +] +configs.extend(vicuna) + + +################# +# LMSYS LongChat +################# +long_chat = [ + # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json + dict( + org="lmsys", + name="longchat-7b-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + condense_ratio=8, + ), + # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json + dict( + org="lmsys", + name="longchat-13b-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + condense_ratio=8, + ), +] +configs.extend(long_chat) + + +###################### +# NousResearch Hermes +###################### +nous_research = [ + # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json + dict( + org="NousResearch", + name="Nous-Hermes-13b", + block_size=2048, + padded_vocab_size=32001, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ) +] +configs.extend(nous_research) + + +############### +# Meta LLaMA 2 +############### +llama_2 = [ + # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-7b{}-hf", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + dict( + org="meta-llama", + name="CodeLlama-2-7b-hf", + block_size=4096, + vocab_size=32016, + padded_vocab_size=32016, + padding_multiple=64, + n_layer=32, + n_head=32, + n_embd=4096, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-13b{}-hf", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-70b{}-hf", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), +] +for c in llama_2: + for kind in ("", "-chat"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +########################## +# Stability AI FreeWilly2 +########################## +freewilly_2 = [ + # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json + dict( + org="stabilityai", + name="FreeWilly2", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ) +] +configs.extend(freewilly_2) + + +name_to_config = {config["name"]: config for config in configs} diff --git a/lit_gpt/fused_cross_entropy.py b/lit_gpt/fused_cross_entropy.py new file mode 100644 index 0000000..ddd8350 --- /dev/null +++ b/lit_gpt/fused_cross_entropy.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023, Tri Dao. + +import torch +import torch.nn as nn +import xentropy_cuda_lib + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + """ + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + ctx.total_classes = world_size * vocab_size + + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels == ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward( + logits, labels_local, smoothing, world_size * vocab_size + ) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). + + lse_allgather = torch.empty( + world_size, batch, dtype=lse_local.dtype, device=lse_local.device + ) + torch.distributed.all_gather_into_tensor( + lse_allgather, lse_local.contiguous(), group=process_group + ) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor") + lse_local = lse_allgather[ + rank_per_sample, torch.arange(batch, device=lse_allgather.device) + ] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( + lse - lse_allgather.sum(dim=0) + ) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels == ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward( + grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes + ) + return grad_logits, None, None, None, None, None, None + + +class FusedCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + inplace_backward=True, + process_group=None, + ): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + self.process_group = process_group + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + if len(input.shape) == 3: + input = input.view(-1, input.size(-1)) + target = target.view(-1) + loss = SoftmaxCrossEntropyLossFn.apply( + input, + target, + self.label_smoothing, + self.ignore_index, + self.inplace_backward, + self.process_group, + ) + if self.reduction == "mean": + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss \ No newline at end of file diff --git a/lit_gpt/fused_rotary_embedding.py b/lit_gpt/fused_rotary_embedding.py new file mode 100644 index 0000000..7ac9da5 --- /dev/null +++ b/lit_gpt/fused_rotary_embedding.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023, Tri Dao. + +import math +from typing import Optional, Tuple + +import rotary_emb +import torch +from einops import rearrange, repeat + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward(ctx, x, cos, sin, interleaved=False, inplace=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + batch, seqlen, nheads, headdim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + x_ro = x[..., :rotary_dim] + x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) + out = torch.empty_like(x) if not inplace else x + out_ro = out[..., :rotary_dim] + if inplace: + o1, o2 = x1, x2 + else: + o1, o2 = ( + out_ro.chunk(2, dim=-1) + if not interleaved + else (out_ro[..., ::2], out_ro[..., 1::2]) + ) + rotary_emb.apply_rotary( + x1, + x2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + o1, + o2, + False, + ) + if not inplace and rotary_dim < headdim: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + ctx.inplace = inplace + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + _, seqlen, _, headdim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + inplace = ctx.inplace + do_ro = do[..., :rotary_dim] + do1, do2 = ( + do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) + ) + dx = torch.empty_like(do) if not inplace else do + if inplace: + dx1, dx2 = do1, do2 + else: + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = ( + dx_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (dx_ro[..., ::2], dx_ro[..., 1::2]) + ) + rotary_emb.apply_rotary( + do1, + do2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + dx1, + dx2, + True, + ) + if not inplace and rotary_dim < headdim: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + return dx, None, None, None, None + + +apply_rotary_emb_func = ApplyRotaryEmb.apply + diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py new file mode 100644 index 0000000..4970e17 --- /dev/null +++ b/lit_gpt/lora.py @@ -0,0 +1,694 @@ +# Derived from https://github.com/microsoft/LoRA +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +r""" + Low Ranking Adaptation for LLMs scheme. + + ┌───────────────────┐ + ┆ h ┆ + └───────────────────┘ + ▲ + | + + + / \ + ┌─────────────────┐ ╭───────────────╮ Matrix initialization: + ┆ ┆ \ B / B = 0 + ┆ pretrained ┆ \ r*d / A = N(0, sigma^2) + ┆ weights ┆ ╰─────────╯ + ┆ ┆ | r | r - rank + ┆ W e R^(d*d) ┆ | ◀─────▶ | + ┆ ┆ ╭─────────╮ + └─────────────────┘ / A \ + ▲ / d*r \ + \ ╰───────────────╯ + \ ▲ + \ / + \ / + ┌───────────────────┐ + ┆ x ┆ + └───────────────────┘ + +With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d, +we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates +for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of +course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen +pretrained weights and thus fine-tune the model. + +The goal of this approach is to move weight updates into a separate matrix which is decomposed with +two matrices of a lower rank. +""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing_extensions import Self + +import lit_gpt +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import Block as BaseBlock +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.model import KVCache, RoPECache +from lit_gpt.utils import map_old_state_dict_weights + + +class LoRALayer(nn.Module): + def __init__(self, r: int, lora_alpha: int, lora_dropout: float): + """Store LoRA specific attributes in a class. + + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__() + assert r >= 0 + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + + +class LoRALinear(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + """LoRA wrapper around linear class. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def reset_parameters(self): + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def merge(self): + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + # Merge the weights and mark it + self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return pretrained + lora + + +class LoRAQKVLinear(LoRALinear): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + n_head: int, + n_query_groups: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: Union[bool, Tuple[bool, bool, bool]] = False, + **kwargs, + ): + """LoRA wrapper around linear class that is used for calculation of q, k and v matrices. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + n_head: number of attention heads + n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`) + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we + don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query` + and `value` but keep `key` without weight updates we should pass `[True, False, True]` + """ + super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.n_head = n_head + self.n_query_groups = n_query_groups + if isinstance(enable_lora, bool): + enable_lora = [enable_lora] * 3 + assert len(enable_lora) == 3 + self.enable_lora = enable_lora + + # Actual trainable parameters + # To better understand initialization let's imagine that we have such parameters: + # ⚬ in_features: 128 (embeddings_size) + # ⚬ out_features: 384 (3 * embedding_size) + # ⚬ r: 2 + # ⚬ enable_lora: [True, False, True] + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128) + enable_q, enable_k, enable_v = enable_lora + self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups) + # qkv_shapes will be used to split a tensor with weights correctly + qkv_shapes = ( + self.linear.in_features * enable_q, + self.kv_embd_size * enable_k, + self.kv_embd_size * enable_v, + ) + self.qkv_shapes = [s for s in qkv_shapes if s] + self.lora_B = nn.Parameter(self.linear.weight.new_zeros(sum(self.qkv_shapes), r)) # (256, 2)) + # Notes about shapes above + # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; + # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in + # F.linear function weights are automatically transposed. In addition conv1d requires channels to + # be before seq length + # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is + # 128*2; 2 tells to have two channels per group for group convolution + + # Scaling: + # This balances the pretrained model`s knowledge and the new task-specific adaptation + # https://lightning.ai/pages/community/tutorial/lora-llm/ + # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set + # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can + # tune these values to your needs. This value can be even slightly greater than 1.0! + # https://github.com/cloneofsimo/lora + self.scaling = self.lora_alpha / self.r + + # Compute the indices + # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values, + # but not keys, then the weights update should be: + # + # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + # [....................................], + # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + # ↑ ↑ ↑ + # ________________________________________ + # | query | key | value | + # ---------------------------------------- + self.lora_ind = [] + if enable_q: + self.lora_ind.extend(range(0, self.linear.in_features)) + if enable_k: + self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size)) + if enable_v: + self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features)) + self.reset_parameters() + + def zero_pad(self, x: torch.Tensor) -> torch.Tensor: + """Properly pad weight updates with zeros. + + If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, + then the weights update should be: + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + [....................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + ↑ ↑ ↑ + ________________________________________ + | query | key | value | + ---------------------------------------- + + Args: + x: tensor with weights update that will be padded with zeros if necessary + + Returns: + A tensor with weight updates and zeros for deselected q, k or v + """ + # we need to do zero padding only if LoRA is disabled for one of QKV matrices + if all(self.enable_lora): + return x + + # Let's image that: + # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size) + # ⚬ embeddings_size: 128 + # ⚬ self.linear.out_features: 384 (3 * embeddings_size) + # ⚬ enable_lora: [True, False, True] + # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected + # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but + # only for key updates (this is where self.lora_ind comes in handy) + # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors + # for example when we want to merge/unmerge LoRA weights and pretrained weights + x = x.transpose(0, 1) + result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) + result = result.view(-1, self.linear.out_features) # (4096, 384) + result = result.index_copy( + 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) + ) # (4096, 256) + return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) + + def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. + + If the number of heads is equal to the number of query groups - grouped queries are disabled + (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized + query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the + input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple + conv layers side by side). + + Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually, + apply each part of the weight matrix to the corresponding input's part and concatenate the result. + + Args: + input: input matrix of shape (B, C, T) + weight: weight matrix of shape (C_output, rank, 1). + "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class). + + Returns: + A tensor with a shape (B, C_output, T) + + """ + if self.n_head == self.n_query_groups: + return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T) + + # Notation: + # ⚬ N: number of enabled LoRA layers (self.enable_lora) + # ⚬ C_output': embeddings size for each LoRA layer (not equal in size) + # ⚬ r: rank of all LoRA layers (equal in size) + + input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) + weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) + return torch.cat( + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) + ) # (B, C_output, T) + + def merge(self): + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + + # Let's assume that: + # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + if self.r > 0 and any(self.enable_lora) and not self.merged: + delta_w = self.conv1d( + self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) + self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + # W = W + delta_W (merge) + self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128) + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Do the forward pass. + + If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. + If not, then multiply pretrained weights with input, apply LoRA on input and do summation. + + Args: + x: input tensor of shape (batch_size, context_length, embedding_size) + + Returns: + Output tensor of shape (batch_size, context_length, 3 * embedding_size) + """ + + # Let's assume that: + # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size) + # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + + # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or not any(self.enable_lora) or self.merged: + return pretrained + after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4) + # For F.conv1d: + # ⚬ input: input tensor of shape (mini-batch, in_channels, iW) + # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW) + after_B = self.conv1d( + after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) + self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) + return pretrained + lora + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: + """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. + + Args: + model: model with LoRA layers + bias: + ``"none"``: all bias weights will be frozen, + ``"lora_only"``: only bias weight for LoRA layers will be unfrozen, + ``"all"``: all bias weights will be unfrozen. + + Raises: + NotImplementedError: if `bias` not in ["none", "lora_only", "all"] + """ + # freeze all layers except LoRA's + for n, p in model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False + + # depending on the `bias` value unfreeze bias weights + if bias == "none": + return + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in model.modules(): + if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_filter(key: str, value: Any) -> bool: + return "lora_" in key + + +@dataclass +class Config(BaseConfig): + """ + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + to_*: either apply LoRA to the specified weights or not + """ + + r: int = 0 + alpha: int = 1 + dropout: float = 0.0 + to_query: bool = False + to_key: bool = False + to_value: bool = False + to_projection: bool = False + to_mlp: bool = False + to_head: bool = False + + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.lora, self._mlp_class) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = LoRALinear( + config.n_embd, + config.padded_vocab_size, + bias=False, + r=(config.r if config.to_head else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + + def forward( + self, + idx: torch.Tensor, + max_seq_length: Optional[int] = None, + input_pos: Optional[torch.Tensor] = None, + lm_head_chunk_size: int = 0, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + B, T = idx.size() + use_kv_cache = input_pos is not None + + block_size = self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + if use_kv_cache: # not relevant otherwise + assert ( + max_seq_length >= T + ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx) # 2 * (block_size, head_size * rotary_percentage) + # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 + if use_kv_cache and self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) # (1, 1, block_size, block_size) + + cos, sin = self.rope_cache + if use_kv_cache: + cos = cos.index_select(0, input_pos) + sin = sin.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + cos = cos[:T] + sin = sin[:T] + mask = None + + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) + + if not use_kv_cache: + for block in self.transformer.h: + x, *_ = block(x, (cos, sin), max_seq_length) + else: + self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) + + x = self.transformer.ln_f(x) + + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (B, T, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, LoRALinear): + module.reset_parameters() + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = {"lm_head.weight": "lm_head.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + def __init__(self, config: Config) -> None: + """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for + parameter-efficient fine-tuning. + + *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for + query, key and value for each head) we can do this in a single pass with a single weight matrix. + """ + # Skip the parent class __init__ altogether and replace it to avoid + # useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = LoRAQKVLinear( + in_features=config.n_embd, + out_features=shape, + r=config.r, + lora_alpha=config.alpha, + lora_dropout=config.dropout, + enable_lora=(config.to_query, config.to_key, config.to_value), + bias=config.bias, + # for MQA/GQA support + n_head=config.n_head, + n_query_groups=config.n_query_groups, + ) + # output projection + self.proj = LoRALinear( + config.n_embd, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_projection else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + self.config = config + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.fc_2 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def merge_lora_weights(model: GPT) -> None: + """Merge LoRA weights into the full-rank weights to speed up inference.""" + for module in model.modules(): + if isinstance(module, LoRALinear): + module.merge() diff --git a/lit_gpt/model.py b/lit_gpt/model.py new file mode 100644 index 0000000..dff915f --- /dev/null +++ b/lit_gpt/model.py @@ -0,0 +1,345 @@ +"""Full definition of a GPT NeoX Language Model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. +""" +import math +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +from lightning_utilities.core.imports import RequirementCache +from typing_extensions import Self +from flash_attn import flash_attn_func +from lit_gpt.config import Config +from xformers.ops import SwiGLU +from .fused_rotary_embedding import apply_rotary_emb_func +RoPECache = Tuple[torch.Tensor, torch.Tensor] +KVCache = Tuple[torch.Tensor, torch.Tensor] +FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") + + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + + def _init_weights(self, module: nn.Module, n_layer) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf + # print module name + if isinstance(module, nn.Embedding): + # RWKV: set it to 1e-4 + torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))) + # torch.nn.init.normal_(module.weight, -1e-4, 1e-4) + elif isinstance(module, nn.Linear): + # fan-in variance scaling intializer + torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + # GPT-NeoX + for name, p in module.named_parameters(): + if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU)): #if use xformer swiglu, fc2 layer will be renamed to w3 + nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(p.shape[-1]) / n_layer) + + + def reset_cache(self) -> None: + self.kv_caches.clear() + if self.mask_cache is not None and self.mask_cache.device.type == "xla": + # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 + self.rope_cache = None + self.mask_cache = None + + def forward( + self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + B, T = idx.size() + use_kv_cache = input_pos is not None + + block_size = self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + if use_kv_cache: # not relevant otherwise + assert ( + max_seq_length >= T + ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx) + # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 + if use_kv_cache and self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) + + cos, sin = self.rope_cache + if use_kv_cache: + cos = cos.index_select(0, input_pos) + sin = sin.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + cos = cos[:T] + sin = sin[:T] + mask = None + + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if not use_kv_cache: + for block in self.transformer.h: + x, *_ = block(x, (cos, sin), max_seq_length) + else: + self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) + + x = self.transformer.ln_f(x) + + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: + return build_rope_cache( + seq_len=self.config.block_size, + n_elem=int(self.config.rotary_percentage * self.config.head_size), + dtype=torch.bfloat16, + device=idx.device, + condense_ratio=self.config.condense_ratio, + ) + + def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: + ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) + + def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]: + B = idx.size(0) + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + k_cache_shape = ( + B, + heads, + max_seq_length, + rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), + ) + v_cache_shape = (B, heads, max_seq_length, self.config.head_size) + device = idx.device + return [ + (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) + for _ in range(self.config.n_layer) + ] + + +class Block(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + self.config = config + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + n_1 = self.norm_1(x) + h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = x + h + self.mlp(n_2) + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + + x = x + h + x = x + self.mlp(self.norm_2(x)) + return x, new_kv_cache + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + + self.config = config + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs) + # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) + + # repeat k and v if necessary + # Peiyuan: we do not need to do this as flash attention 2 already support GQA + # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) + # # for MHA this is a no-op + # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) + k = k.reshape(B, T, -1, self.config.head_size) + v = v.reshape(B, T, -1, self.config.head_size) + + cos, sin = rope + + # apply rope in fp32 significanly stabalize training + # fused rope expect (batch_size, seqlen, nheads, headdim) + q = apply_rotary_emb_func(q, cos, sin, False, True) + k = apply_rotary_emb_func(k, cos, sin, False, True) + + # n_elem = int(self.config.rotary_percentage * self.config.head_size) + + # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) + # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) + # print( (q_roped - q).sum()) + # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) + # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) + + if kv_cache is not None: + cache_k, cache_v = kv_cache + cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) + # check if reached token limit + if input_pos[-1] >= max_seq_length: + input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) + # shift 1 position to the left + cache_k = torch.roll(cache_k, -1, dims=2) + cache_v = torch.roll(cache_v, -1, dims=2) + k = cache_k.index_copy_(2, input_pos, k) + v = cache_v.index_copy_(2, input_pos, v) + kv_cache = k, v + + y = self.scaled_dot_product_attention(q, k, v, mask=mask) + + y = y.reshape(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.proj(y) + + return y, kv_cache + + def scaled_dot_product_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + ): + scale = 1.0 / math.sqrt(self.config.head_size) + if ( + FlashAttention2Available + and mask is None + and q.device.type == "cuda" + and q.dtype in (torch.float16, torch.bfloat16) + ): + from flash_attn import flash_attn_func + + return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True) + assert False + return y.transpose(1, 2) + + +class GptNeoxMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.gelu(x) + return self.proj(x) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x_fc_1 = self.fc_1(x) + # x_fc_2 = self.fc_2(x) + # x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + # return self.proj(x) + return self.swiglu(x) + + +def build_rope_cache( + seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1 +) -> RoPECache: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta) + + cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) + + # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding + if dtype == torch.bfloat16: + return cos.bfloat16(), sin.bfloat16() + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + return cos.half(), sin.half() + return cos, sin + + +def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.type_as(x) diff --git a/lit_gpt/packed_dataset.py b/lit_gpt/packed_dataset.py new file mode 100644 index 0000000..1b4b7dc --- /dev/null +++ b/lit_gpt/packed_dataset.py @@ -0,0 +1,235 @@ +# Very loosely inspired by indexed_dataset in Fairseq, Megatron +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py + + +import os +import random +import struct + +import numpy as np +import torch +from torch.utils.data import IterableDataset, get_worker_info + +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} + + +def code(dtype): + for k in dtypes: + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +HDR_MAGIC = b"LITPKDS" +HDR_SIZE = 24 # bytes + + +class PackedDataset(IterableDataset): + def __init__( + self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 + ): + self._filenames = filenames + self._n_chunks = n_chunks + self._block_size = block_size + self._seed = seed + self._shuffle = shuffle + self._wrap = wrap + self._num_processes = num_processes + self._process_rank = process_rank + + def __iter__(self): + worker_info = get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + num_shards = num_workers * self._num_processes + shard_id = self._process_rank * num_workers + worker_id + + max_num_files = len(self._filenames) // num_shards * num_shards + filenames = self._filenames[shard_id:max_num_files:num_shards] + + return PackedDatasetIterator( + filenames=filenames, + n_chunks=self._n_chunks, + block_size=self._block_size, + seed=self._seed, + shuffle=self._shuffle, + wrap=self._wrap, + ) + + +class PackedDatasetBuilder(object): + def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None): + if dtype == "auto": + if vocab_size is None: + raise ValueError("vocab_size cannot be None when dtype='auto'") + if vocab_size is not None and vocab_size < 65500: + self._dtype = np.uint16 + else: + self._dtype = np.int32 + else: + self._dtype = dtype + self._counter = 0 + self._chunk_size = chunk_size + self._outdir = outdir + self._prefix = prefix + self._sep_token = sep_token + self._arr = np.zeros(self._chunk_size, dtype=self._dtype) + self._arr.fill(self._sep_token) + self._idx = 0 + self._version = 1 + self._filenames = [] + + def _write_chunk(self): + filename = f"{self._prefix}_{self._counter:010d}.bin" + filename = os.path.join(self._outdir, filename) + + with open(filename, "wb") as f: + f.write(HDR_MAGIC) + f.write(struct.pack(" self._chunk_size: + part_len = self._chunk_size - self._idx + self._arr[self._idx : self._idx + part_len] = arr[:part_len] + self._write_chunk() + arr = arr[part_len:] + + arr_len = arr.shape[0] + self._arr[self._idx : self._idx + arr_len] = arr + self._idx += arr_len + + def write_reminder(self): + self._write_chunk() + + +class PackedDatasetIterator: + def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): + self._seed = seed + self._shuffle = shuffle + self._rng = np.random.default_rng(seed) if shuffle else None + self._block_idxs = None + + self._wrap = wrap + + # TODO: instead of filenames, we could have a single text stream + # (or text file) with the sequence of all files to be + # fetched/loaded. + self._filenames = filenames + self._file_idx = 0 + + self._n_chunks = n_chunks + + self._dtype = None + self._block_size = block_size + self._n_blocks = None + + self._mmaps = [] + self._buffers = [] + + self._block_idxs = [] + self._curr_idx = 0 + + self._load_n_chunks() + + def _read_header(self, path): + with open(path, "rb") as f: + magic = f.read(len(HDR_MAGIC)) + assert magic == HDR_MAGIC, "File doesn't match expected format." + version = struct.unpack(" len(self._filenames[self._file_idx :]): + # if not self._wrap: + # raise StopIteration + self._file_idx = 0 + + for i in range(self._n_chunks): + filename = self._filenames[self._file_idx + i] + if self._dtype is None: + self._dtype, self._chunk_size = self._read_header(filename) + self._n_blocks = self._chunk_size // self._block_size + # TODO: check header matches with previous files + mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) + self._mmaps.append(mmap) + self._buffers.append(memoryview(mmap)) + + self._file_idx += self._n_chunks + n_all_blocks = self._n_chunks * self._n_blocks + + self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) + + self._curr_idx = 0 + + def __del__(self): + self._close_mmaps() + del self._mmaps + del self._buffers + + def __iter__(self): + return self + + def __next__(self): + if self._curr_idx >= len(self._block_idxs): + self._load_n_chunks() + # TODO: trigger fetching next next n_chunks if remote + block_idx = self._block_idxs[self._curr_idx] + chunk_id = block_idx // self._n_blocks + buffer = self._buffers[chunk_id] + elem_id = (block_idx % self._n_blocks) * self._block_size + offset = np.dtype(self._dtype).itemsize * elem_id + arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + self._curr_idx += 1 + return torch.from_numpy(arr.astype(np.int64)) + + +class CombinedDataset(IterableDataset): + def __init__(self, datasets, seed, weights=None): + self._seed = seed + self._datasets = datasets + self._weights = weights + n_datasets = len(datasets) + if weights is None: + self._weights = [1 / n_datasets] * n_datasets + + def __iter__(self): + return CombinedDatasetIterator(self._datasets, self._seed, self._weights) + + +class CombinedDatasetIterator: + def __init__(self, datasets, seed, weights): + self._datasets = [iter(el) for el in datasets] + self._weights = weights + self._rng = random.Random(seed) + + def __next__(self): + (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) + return next(dataset) diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py new file mode 100644 index 0000000..1c7362a --- /dev/null +++ b/lit_gpt/rmsnorm.py @@ -0,0 +1,842 @@ +import torch +# Copyright (c) 2022, Tri Dao. +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16 + +import dropout_layer_norm +import torch +from torch.nn import init + + +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes""" + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + +def _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + rowscale, + colscale, + None, + None, + dropout_p, + epsilon, + 1.0, + 0, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. + """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + None, + None, + dropout_p, + 1.0, + 0, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. + """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma0.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( + x0mat, + x1mat, + residualmat, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask0 and dmask1 are None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma + + +def _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + """ + hidden_size = gamma0.numel() + xmat = x.view((-1, hidden_size)) + dz0mat = dz0.view(xmat.shape) + dz1mat = dz1.view(xmat.shape) if dz1 is not None else None + dxmat = dx.view(xmat.shape) if dx is not None else None + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + *rest, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( + dz0mat, + dz1mat, + dxmat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 + + +class DropoutAddLayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward( + xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + if not return_dmask: + return ( + zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) + ) + else: + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return ( + (zmat.view(x0.shape), dmask) + if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) + ) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + None, + dcolscale, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward( + xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) + else: + z = zmat.view(z_shape) + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + ctx.rowscale_const, + ctx.x0_numrows, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + dcolscale, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_x1 = x1 is not None + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta0 is not None + z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) + if not return_dmask: + return z if not prenorm else (*z, xmat.view(x0.shape)) + else: + dmask0 = ( + dmask0.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + dmask1 = ( + dmask1.view(x0.shape) + if dropout_p > 0.0 and x1 is not None + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask0) + ctx.mark_non_differentiable(dmask1) + return ( + (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) + ) + + @staticmethod + def backward(ctx, dz0, dz1, *args): + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! + dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_x1 = ctx.has_x1 + has_residual = ctx.has_residual + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + ) = _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + return ( + dx0, + dx1, + dresidual, + dgamma0, + dbeta0 if ctx.has_beta else None, + dgamma1, + dbeta1 if ctx.has_beta else None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm(x, weight, bias, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) + + +def dropout_add_layer_norm( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + rowscale=None, + layerscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, + residual, + weight, + bias, + rowscale, + layerscale, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_subset( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + layerscale=None, + x0_subset=None, + out_subset=None, + rowscale_const=1.0, + out_numrows=0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, + residual, + weight, + bias, + layerscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_parallel_residual( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + prenorm=False, + p=0.0, + eps=1e-5, + residual_in_fp32=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, residual=None): + return dropout_add_layer_norm( + x0, + residual, + self.weight, + self.bias, + self.p if self.training else 0.0, + self.eps, + prenorm=self.prenorm, + residual_in_fp32=self.residual_in_fp32, + ) + +def rms_norm(x, weight, epsilon): + return DropoutAddLayerNormFn.apply( + x, None, weight, None, None, None, 0.0, epsilon, False, False, True + ) +class FusedRMSNorm(torch.nn.Module): + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(size)) + self.dim = dim + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return self.weight * x_normed + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) diff --git a/lit_gpt/speed_monitor.py b/lit_gpt/speed_monitor.py new file mode 100644 index 0000000..fa81b18 --- /dev/null +++ b/lit_gpt/speed_monitor.py @@ -0,0 +1,408 @@ +import time +from collections import deque +from contextlib import nullcontext +from typing import Any, Callable, Deque, Dict, Optional + +import torch +from lightning import Callback, Fabric, LightningModule, Trainer +from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only +from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only +from torch.utils.flop_counter import FlopCounterMode +import math +from lit_gpt import GPT, Config +from lit_gpt.utils import num_parameters + +GPU_AVAILABLE_FLOPS = { + # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet + # nvidia publishes spec sheet with a 2x sparsity factor + "h100-sxm": { + "64-true": 67e12, + "32-true": 67e12, + "16-true": 1.979e15 / 2, + "16-mixed": 1.979e15 / 2, + "bf16-true": 1.979e15 / 2, + "bf16-mixed": 1.979e15 / 2, + "8-true": 3.958e15 / 2, + "8-mixed": 3.958e15 / 2, + }, + "h100-pcie": { + "64-true": 51e12, + "32-true": 51e12, + "16-true": 1.513e15 / 2, + "16-mixed": 1.513e15 / 2, + "bf16-true": 1.513e15 / 2, + "bf16-mixed": 1.513e15 / 2, + "8-true": 3.026e15 / 2, + "8-mixed": 3.026e15 / 2, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + "a100": { + "64-true": 19.5e12, + "32-true": 19.5e12, + "16-true": 312e12, + "16-mixed": 312e12, + "bf16-true": 312e12, + "bf16-mixed": 312e12, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf + "a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12}, + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12}, + "v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12}, + "v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12}, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + "t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12}, + # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf + "quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12}, +} + +TPU_AVAILABLE_FLOPS = { + # flop count for each TPU generation is the same for all precisions + # since bfloat16 precision is always used for performing matrix operations + # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16 + # source: https://arxiv.org/pdf/1907.10701.pdf + "v2": 45e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3 + "v3": 123e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4 + "v4": 275e12, +} + + +def get_flops_available(device: torch.device, precision: str) -> Optional[float]: + if device.type == "cuda": + device_name = torch.cuda.get_device_name(device).lower() + if "h100" in device_name and "hbm3" in device_name: + device_name = "h100-sxm" + elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): + device_name = "h100-pcie" + elif "a100" in device_name: + device_name = "a100" + elif "a10g" in device_name: + device_name = "a10g" + elif "v100-sxm" in device_name: + device_name = "v100-sxm" + elif "v100-pcie" in device_name: + device_name = "v100-pcie" + elif "t4" in device_name: + device_name = "t4" + elif "quadro rtx 5000" in device_name: + device_name = "quadro rtx 5000" + else: + device_name = None + + if device_name is not None: + try: + return int(GPU_AVAILABLE_FLOPS[device_name][precision]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with precision: {precision}; " + "MFU cannot be calculated and reported." + ) + elif device.type == "xla": + from torch_xla.experimental import tpu + + device_name = tpu.get_tpu_env()["TYPE"].lower() + try: + return int(TPU_AVAILABLE_FLOPS[device_name]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with precision: {precision}; " + "MFU cannot be calculated and reported." + ) + + return None + + +# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py + + +class SpeedMonitorBase: + """Logs the training throughput and utilization. + + +-------------------------------------+-----------------------------------------------------------+ + | Key | Logged data | + +=====================================+===========================================================+ + | | Rolling average (over `window_size` most recent | + | `throughput/batches_per_sec` | batches) of the number of batches processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/samples_per_sec` | batches) of the number of samples processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. | + | | This may include padding depending on dataset | + +-------------------------------------+-----------------------------------------------------------+ + | | Estimates flops by `flops_per_batch * batches_per_sec` | + | `throughput/flops_per_sec` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/tokens_per_sec` divided by world size. This | + | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/flops_per_sec` divided by world size. Only | + | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/device/flops_per_sec` divided by world size. | + | `throughput/device/mfu` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `time/train` | Total elapsed training time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/val` | Total elapsed validation time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/total` | Total elapsed time (time/train + time/val) | + +-------------------------------------+-----------------------------------------------------------+ + + Notes: + - The implementation assumes that devices are homogeneous as it normalizes by the world size. + - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or + batches/sec to measure throughput under this circumstance. + - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``. + There is no widespread, realistic, and reliable implementation to compute them. + We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which + will almost always be an overestimate when compared to the true value. + + Args: + window_size (int, optional): Number of batches to use for a rolling average of throughput. + Defaults to 100. + time_unit (str, optional): Time unit to use for `time` logging. Can be one of + 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'. + """ + + def __init__( + self, + flops_available: float, + log_dict: Callable[[Dict, int], None], + window_size: int = 100, + time_unit: str = "hours", + log_iter_interval: int = 1, + ): + self.flops_available = flops_available + self.log_dict = log_dict + self.log_iter_interval = log_iter_interval + # Track the batch num samples and wct to compute throughput over a window of batches + self.history_samples: Deque[int] = deque(maxlen=window_size + 1) + self.history_training_loss: Deque[int] = deque(maxlen=log_iter_interval) + self.history_wct: Deque[float] = deque(maxlen=window_size + 1) + self.history_lengths: Deque[int] = deque(maxlen=window_size + 1) + self.history_flops: Deque[int] = deque(maxlen=window_size + 1) + + self.divider = 1 + if time_unit == "seconds": + self.divider = 1 + elif time_unit == "minutes": + self.divider = 60 + elif time_unit == "hours": + self.divider = 60 * 60 + elif time_unit == "days": + self.divider = 60 * 60 * 24 + else: + raise ValueError( + f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".' + ) + + # Keep track of time spent evaluating + self.total_eval_wct = 0.0 + self.iter = -1 + + def on_train_batch_end( + self, + samples: int, # total samples seen (per device) + train_elapsed: float, # total training time (seconds) + world_size: int, + flops_per_batch: Optional[int] = None, # (per device) + lengths: Optional[int] = None, # total length of the samples seen (per device) + train_loss: Optional[float] = None, + ): + self.iter += 1 + metrics = {} + + self.history_samples.append(samples) + self.history_training_loss.append(train_loss) + if lengths is not None: + self.history_lengths.append(lengths) + # if lengths are passed, there should be as many values as samples + assert len(self.history_samples) == len(self.history_lengths) + self.history_wct.append(train_elapsed) + if len(self.history_wct) == self.history_wct.maxlen: + elapsed_batches = len(self.history_samples) - 1 + elapsed_samples = self.history_samples[-1] - self.history_samples[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + samples_per_sec = elapsed_samples * world_size / elapsed_wct + dev_samples_per_sec = elapsed_samples / elapsed_wct + metrics.update( + { + "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct, + "throughput/samples_per_sec": samples_per_sec, + "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct, + "throughput/device/samples_per_sec": dev_samples_per_sec, + } + ) + if lengths is not None: + elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0]) + avg_length = elapsed_lengths / elapsed_batches + metrics.update( + { + "throughput/tokens_per_sec": samples_per_sec * avg_length, + "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length, + "total_tokens": avg_length * world_size * samples, + } + ) + if train_loss is not None: + avg_loss = sum(self.history_training_loss) / len(self.history_training_loss) + metrics.update( + { + "metric/train_loss": avg_loss, + "metric/train_ppl": math.exp(avg_loss) + } + ) + + if flops_per_batch is not None: + # sum of flops per batch across ranks + self.history_flops.append(flops_per_batch * world_size) + if len(self.history_flops) == self.history_flops.maxlen: + elapsed_flops = sum(self.history_flops) - self.history_flops[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + flops_per_sec = elapsed_flops / elapsed_wct + device_flops_per_sec = flops_per_sec / world_size + metrics.update( + {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec} + ) + if self.flops_available: + metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available + + metrics.update( + { + "time/train": train_elapsed / self.divider, + "time/val": self.total_eval_wct / self.divider, + "time/total": (train_elapsed + self.total_eval_wct) / self.divider, + "samples": samples, + } + ) + if self.iter % self.log_iter_interval == 0: + self.log_dict(metrics, self.iter//self.log_iter_interval) + + def eval_end(self, eval_elapsed: float): + self.total_eval_wct += eval_elapsed # seconds + + +class SpeedMonitorFabric(SpeedMonitorBase): + def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None: + # TODO: this will not work properly if a precision plugin is passed to Fabric + flops_available = get_flops_available(fabric.device, fabric._connector._precision_input) + super().__init__(flops_available, fabric.log_dict, *args, **kwargs) + + @fabric_rank_zero_only + def on_train_batch_end(self, *args: Any, **kwargs: Any): + super().on_train_batch_end(*args, **kwargs) + + +class SpeedMonitorCallback(Callback): + def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None: + super().__init__() + self.speed_monitor: Optional[SpeedMonitorBase] = None + self.speed_monitor_kwargs = kwargs + self.length_fn = length_fn + self.batch_size = batch_size + self.eval_t0: int = 0 + self.train_t0: int = 0 + self.total_lengths: int = 0 + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + if self.speed_monitor is not None: + return # already setup + # TODO: this will not work properly if a precision plugin is passed to Trainer + flops_available = get_flops_available( + trainer.strategy.root_device, trainer._accelerator_connector._precision_flag + ) + self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs) + + @trainer_rank_zero_only + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if trainer.fit_loop._should_accumulate(): + return + + self.train_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_train_batch_end( + self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int + ) -> None: + self.total_lengths += self.length_fn(batch) + if trainer.fit_loop._should_accumulate(): + return + train_elapsed = time.perf_counter() - self.train_t0 + assert self.speed_monitor is not None + iter_num = trainer.fit_loop.total_batch_idx + assert (measured_flops := pl_module.measured_flops) is not None + self.speed_monitor.on_train_batch_end( + (iter_num + 1) * self.batch_size, + train_elapsed, + # this assumes that device FLOPs are the same and that all devices have the same batch size + trainer.world_size, + flops_per_batch=measured_flops, + lengths=self.total_lengths, + ) + + @trainer_rank_zero_only + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.eval_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + eval_elapsed = time.perf_counter() - self.eval_t0 + assert self.speed_monitor is not None + self.speed_monitor.eval_end(eval_elapsed) + + +def flops_per_param(config: Config, n_params: int) -> int: + flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation + # this assumes that all samples have a fixed length equal to the block size + # which is most likely false during finetuning + flops_per_seq = flops_per_token * config.block_size + attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2)) + return flops_per_seq + attn_flops_per_seq + + +def estimate_flops(model: GPT) -> int: + """Measures estimated FLOPs for MFU. + + Refs: + * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 + * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 + """ + # using all parameters for this is a naive over estimation because not all model parameters actually contribute to + # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage + # (~10%) compared to the measured FLOPs, making those lower but more realistic. + # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. + n_trainable_params = num_parameters(model, requires_grad=True) + trainable_flops = flops_per_param(model.config, n_trainable_params) + # forward + backward + gradients (assumes no gradient accumulation) + ops_per_step = 3 if model.training else 1 + n_frozen_params = num_parameters(model, requires_grad=False) + frozen_flops = flops_per_param(model.config, n_frozen_params) + # forward + backward + frozen_ops_per_step = 2 if model.training else 1 + return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops + + +def measure_flops(model: GPT, x: torch.Tensor) -> int: + """Measures real FLOPs for HFU""" + flop_counter = FlopCounterMode(model, display=False) + ctx = nullcontext() if model.training else torch.no_grad() + with ctx, flop_counter: + y = model(x) + if model.training: + y.sum().backward() + return flop_counter.get_total_flops() diff --git a/lit_gpt/tokenizer.py b/lit_gpt/tokenizer.py new file mode 100644 index 0000000..a076c13 --- /dev/null +++ b/lit_gpt/tokenizer.py @@ -0,0 +1,77 @@ +import json +from pathlib import Path +from typing import Optional + +import torch + + +class Tokenizer: + def __init__(self, checkpoint_dir: Path) -> None: + # some checkpoints have both files, `.model` takes precedence + if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): + from sentencepiece import SentencePieceProcessor + + self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) + self.backend = "sentencepiece" + self.bos_id = self.processor.bos_id() + self.eos_id = self.processor.eos_id() + elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): + from tokenizers import Tokenizer as HFTokenizer + + self.processor = HFTokenizer.from_file(str(vocabulary_path)) + self.backend = "huggingface" + with open(checkpoint_dir / "tokenizer_config.json") as fp: + config = json.load(fp) + bos_token = config.get("bos_token") + self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None + self.eos_id = self.token_to_id(config["eos_token"]) + else: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + if self.backend == "huggingface": + return self.processor.get_vocab_size(with_added_tokens=False) + if self.backend == "sentencepiece": + return self.processor.vocab_size() + raise RuntimeError + + def token_to_id(self, token: str) -> int: + if self.backend == "huggingface": + id_ = self.processor.token_to_id(token) + elif self.backend == "sentencepiece": + id_ = self.processor.piece_to_id(token) + else: + raise RuntimeError + if id_ is None: + raise ValueError(f"token {token!r} not found in the collection.") + return id_ + + def encode( + self, + string: str, + device: Optional[torch.device] = None, + bos: bool = False, + eos: bool = False, + max_length: int = -1, + ) -> torch.Tensor: + if self.backend == "huggingface": + tokens = self.processor.encode(string).ids + elif self.backend == "sentencepiece": + tokens = self.processor.encode(string) + else: + raise RuntimeError + if bos: + bos_id = self.bos_id + if bos_id is None: + raise NotImplementedError("This tokenizer does not defined a bos token") + tokens = [bos_id] + tokens + if eos: + tokens = tokens + [self.eos_id] + if max_length > 0: + tokens = tokens[:max_length] + return torch.tensor(tokens, dtype=torch.int, device=device) + + def decode(self, tensor: torch.Tensor) -> str: + tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() + return self.processor.decode(tokens) diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py new file mode 100644 index 0000000..d1d7bc6 --- /dev/null +++ b/lit_gpt/utils.py @@ -0,0 +1,505 @@ +"""Utility functions for training and inference.""" + +import pickle +import sys +import warnings +from contextlib import contextmanager +from functools import partial +from io import BytesIO +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union + +import torch +import torch.nn as nn +import torch.utils._device +from lightning.fabric.loggers import CSVLogger +from torch.serialization import normalize_storage_type + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: + return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) + + +@contextmanager +def quantization(mode: Optional[str] = None): + if mode is None: + yield + return + + if mode == "bnb.int8": + from quantize.bnb import InferenceLinear8bitLt + + quantized_linear_cls = InferenceLinear8bitLt + elif mode == "bnb.fp4": + from quantize.bnb import Linear4bit + + # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.fp4-dq": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.nf4": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.nf4-dq": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "gptq.int4": + from quantize.gptq import ColBlockQuantizedLinear + + class QuantizedLinear(ColBlockQuantizedLinear): + def __init__(self, *args, **kwargs): + super().__init__(*args, bits=4, tile_cols=-1, **kwargs) + + quantized_linear_cls = QuantizedLinear + else: + raise ValueError(f"Unknown quantization mode: {mode}") + + torch_linear_cls = torch.nn.Linear + torch.nn.Linear = quantized_linear_cls + yield + torch.nn.Linear = torch_linear_cls + + +# this is taken from torchhacks https://github.com/lernapparat/torchhacks + + +class NotYetLoadedTensor: + def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + self.metatensor = metatensor + self.archiveinfo = archiveinfo + self.storageinfo = storageinfo + self.rebuild_args = rebuild_args + + @classmethod + def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): + ret = func(*args) + if isinstance(ret, NotYetLoadedTensor): + old_lt = ret._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) + + ret._load_tensor = _load_tensor + return ret + return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) + + @classmethod + def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): + if isinstance(data, NotYetLoadedTensor): + old_lt = data._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) + + data._load_tensor = _load_tensor + return data + return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) + + @classmethod + def rebuild_tensor_v2( + cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None + ): + rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) + metatensor = torch._utils._rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata + ) + storageinfo = storage.archiveinfo + return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) + + def _load_tensor(self): + name, storage_cls, fn, device, size = self.storageinfo + dtype = self.metatensor.dtype + + uts = ( + self.archiveinfo.zipfile_context.zf.get_storage_from_record( + f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage + ) + ._typed_storage() + ._untyped_storage + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) + return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] + return func(*loaded_args, **kwargs) + # gc.collect would be costly here, maybe do it optionally + + def __getattr__(self, name): + # properties + ## TODO: device, is_...?? + ## TODO: mH, mT, H, T, data, imag, real + ## name ??? + if name in { + "dtype", + "grad", + "grad_fn", + "layout", + "names", + "ndim", + "output_nr", + "requires_grad", + "retains_grad", + "shape", + "volatile", + }: + return getattr(self.metatensor, name) + if name in {"size"}: + return getattr(self.metatensor, name) + # materializing with contiguous is needed for quantization + if name in {"contiguous"}: + return getattr(self._load_tensor(), name) + + raise AttributeError(f"{type(self)} does not have {name}") + + def __repr__(self): + return f"NotYetLoadedTensor({repr(self.metatensor)})" + + +class LazyLoadingUnpickler(pickle.Unpickler): + def __init__(self, file, zipfile_context): + super().__init__(file) + self.zipfile_context = zipfile_context + + def find_class(self, module, name): + res = super().find_class(module, name) + if module == "torch._utils" and name == "_rebuild_tensor_v2": + return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) + if module == "torch._tensor" and name == "_rebuild_from_type_v2": + return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) + if module == "torch._utils" and name == "_rebuild_parameter": + return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) + return res + + def persistent_load(self, pid): + name, cls, fn, device, size = pid + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") + s.archiveinfo = pid + return s + + +class lazy_load: + def __init__(self, fn): + self.zf = torch._C.PyTorchFileReader(str(fn)) + with BytesIO(self.zf.get_record("data.pkl")) as pkl: + mup = LazyLoadingUnpickler(pkl, self) + self.sd = mup.load() + + def __enter__(self): + return self.sd + + def __exit__(self, exc_type, exc_val, exc_tb): + del self.zf # I don't think there is a way to force closing... + self.zf = None + + +def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: + files = { + "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), + "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), + "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( + checkpoint_dir / "tokenizer.model" + ).is_file(), + "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), + } + if checkpoint_dir.is_dir(): + if all(files.values()): + # we're good + return + problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" + else: + problem = " is not a checkpoint directory" + + # list locally available checkpoints + available = list(Path("checkpoints").glob("*/*")) + if available: + options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) + extra = f"\nYou have downloaded locally:{options}\n" + else: + extra = "" + + error_message = ( + f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." + "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" + f"{extra}\nSee all download options by running:\n python scripts/download.py" + ) + print(error_message, file=sys.stderr) + raise SystemExit(1) + + +class SavingProxyForStorage: + def __init__(self, obj, saver, protocol_version=5): + self.protocol_version = protocol_version + self.saver = saver + if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): + raise TypeError(f"expected storage, not {type(obj)}") + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + if isinstance(obj, torch.storage.TypedStorage): + # PT upstream wants to deprecate this eventually... + storage = obj._untyped_storage + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + else: + storage = obj + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + storage_key = saver._write_storage_and_return_key(storage) + location = torch.serialization.location_tag(storage) + + self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) + + def __reduce_ex__(self, protocol_version): + assert False, "this should be handled with out of band" + + +class SavingProxyForTensor: + def __init__(self, tensor, saver, protocol_version=5): + self.protocol_version = protocol_version + self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) + assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" + storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) + self.reduce_args = (storage_proxy, *other_reduce_args) + + def __reduce_ex__(self, protocol_version): + if protocol_version != self.protocol_version: + raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") + return self.reduce_ret_fn, self.reduce_args + + +class IncrementalPyTorchPickler(pickle.Pickler): + def __init__(self, saver, *args, **kwargs): + super().__init__(*args, **kwargs) + self.storage_dtypes = {} + self.saver = saver + self.id_map = {} + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + def persistent_id(self, obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, SavingProxyForStorage): + return obj.storage_info + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in self.storage_dtypes: + if storage_dtype != self.storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + self.storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = self.id_map.get(storage._cdata) + if storage_key is None: + storage_key = self.saver._write_storage_and_return_key(storage) + self.id_map[storage._cdata] = storage_key + location = torch.serialization.location_tag(storage) + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + +class incremental_save: + def __init__(self, name): + self.name = name + self.zipfile = torch._C.PyTorchFileWriter(str(name)) + self.has_saved = False + self.next_key = 0 + + def __enter__(self): + return self + + def store_early(self, tensor): + if isinstance(tensor, torch.Tensor): + return SavingProxyForTensor(tensor, self) + raise TypeError(f"can only store tensors early, not {type(tensor)}") + + def save(self, obj): + if self.has_saved: + raise RuntimeError("have already saved") + # Write the pickle data for `obj` + data_buf = BytesIO() + pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) + pickler.dump(obj) + data_value = data_buf.getvalue() + self.zipfile.write_record("data.pkl", data_value, len(data_value)) + self.has_saved = True + + def _write_storage_and_return_key(self, storage): + if self.has_saved: + raise RuntimeError("have already saved") + key = self.next_key + self.next_key += 1 + name = f"data/{key}" + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zipfile.write_record(name, storage.data_ptr(), num_bytes) + return key + + def __exit__(self, type, value, traceback): + self.zipfile.write_end_of_file() + + +T = TypeVar("T") + + +def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: + logger = cls(*args, **kwargs) + + def merge_by(dicts, key): + from collections import defaultdict + + out = defaultdict(dict) + for d in dicts: + if key in d: + out[d[key]].update(d) + return [v for _, v in sorted(out.items())] + + def save(self) -> None: + """Overridden to merge CSV by the step number.""" + import csv + + if not self.metrics: + return + metrics = merge_by(self.metrics, "step") + keys = sorted({k for m in metrics for k in m}) + with self._fs.open(self.metrics_file_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(metrics) + + logger.experiment.save = MethodType(save, logger.experiment) + + return logger + + +def chunked_cross_entropy( + logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 +) -> torch.Tensor: + # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate + # the memory usage in fine-tuning settings with low number of parameters. + # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing + # the memory spike's magnitude + + # lm_head was chunked (we are fine-tuning) + if isinstance(logits, list): + # don't want to chunk cross entropy + if chunk_size == 0: + logits = torch.cat(logits, dim=1) + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # chunk cross entropy + logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] + target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + # no chunking at all + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + if chunk_size == 0: + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # lm_head wasn't chunked, chunk cross entropy + logit_chunks = logits.split(chunk_size) + target_chunks = targets.split(chunk_size) + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + +def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: + for checkpoint_name, attribute_name in mapping.items(): + full_checkpoint_name = prefix + checkpoint_name + if full_checkpoint_name in state_dict: + full_attribute_name = prefix + attribute_name + state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) + return state_dict + + +def get_default_supported_precision(training: bool, tpu: bool = False) -> str: + """Return default precision that is supported by the hardware. + + Args: + training: `-mixed` or `-true` version of the precision to use + tpu: whether TPU device is used + + Returns: + default precision that is suitable for the task and is supported by the hardware + """ + if tpu: + return "32-true" + if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): + return "bf16-mixed" if training else "bf16-true" + return "16-mixed" if training else "16-true" diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py new file mode 100644 index 0000000..f01ef1c --- /dev/null +++ b/pretrain/tinyllama.py @@ -0,0 +1,395 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "tiny_LLaMA_1b" +name = "tinyllama_1b" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-4 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger() + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=(3407 + fabric.global_rank), + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for train_data in train_dataloader: + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + + + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = glob.glob(str(data_dir / f"{prefix}*")) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cf1edab --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +# torch>=2.1.0dev +lightning @ git+https://github.com/Lightning-AI/lightning@master +jsonargparse[signatures] # CLI +pandas +pyarrow +tokenizers +sentencepiece +wandb +zstd +# other optional dependencies are +# sentencepiece # pythia, falcon, redpajama +# tokenizers # llama-based models +# bitsandbytes>=0.41.1 # quantize/bnb.py +# scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released +# datasets # quantize/gptq.py +# zstandard # scripts/prepare_redpajama.py +# git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py new file mode 100644 index 0000000..09b6ed8 --- /dev/null +++ b/scripts/convert_lit_checkpoint.py @@ -0,0 +1,264 @@ +import contextlib +import gc +import sys +from functools import partial +from pathlib import Path +from typing import Dict, Literal, Optional, Tuple, Union + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load +# from scripts.convert_hf_checkpoint import layer_template, load_param + + +def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: + split = layer_name.split(".") + number = int(split[idx]) + split[idx] = "{}" + from_name = ".".join(split) + return from_name, number + + +def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: + if hasattr(param, "_load_tensor"): + # support tensors loaded via `lazy_load()` + print(f"Loading {name!r} into RAM") + param = param._load_tensor() + if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: + print(f"Converting {name!r} from {param.dtype} to {dtype}") + param = param.to(dtype) + return param +def copy_weights_falcon( + size: Literal["7b", "40b"], + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +): + weight_map = { + "transformer.wte.weight": "transformer.word_embeddings.weight", + "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", + "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": "transformer.ln_f.bias", + "transformer.ln_f.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + # the original model definition is different for each size + if size == "7b": + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", + } + ) + elif size == "40b": + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", + "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", + "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", + } + ) + else: + raise NotImplementedError + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_gpt_neox( + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "gpt_neox.embed_in.weight", + "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", + "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", + "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", + "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", + "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", + "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", + "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", + "lm_head.weight": "embed_out.weight", + } + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_llama( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +): + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", + "transformer.ln_f.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", + } + for name, param in lit_weights.items(): + if name.endswith(".attn.attn.weight"): + from_name, number = layer_template(name, 2) + q = "model.layers.{}.self_attn.q_proj.weight".format(number) + k = "model.layers.{}.self_attn.k_proj.weight".format(number) + v = "model.layers.{}.self_attn.v_proj.weight".format(number) + qkv = load_param(param, name, None) + qp, kp, vp = tensor_split(qkv, config) + for to_name, param in zip((q, k, v), (qp, kp, vp)): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + elif "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name] + + if to_name is None: + continue + to_name = to_name.format(number) + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def tensor_split( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def kstart(start, blen, klen) -> int: + """returns start index of keys in batch""" + return start + (blen - (klen * 2)) + + def vstart(start, blen, klen) -> int: + """returns start index of values in batch""" + return start + blen - klen + + def vend(start, blen) -> int: + """returns last index of values in batch""" + return start + blen + + # num observations + nobs = param.shape[0] + # batch length + blen = nobs // config.n_query_groups + # key length in batch + klen = config.head_size + # value length in batch + vlen = config.head_size + # the starting index of each new batch + starts = range(0, nobs, blen) + # the indices to splice on + splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] + + qc = () + kc = () + vc = () + + for splice in splices: + qs, ks, vs, ve = splice + qc += (param[qs:ks, :],) + kc += (param[ks:vs, :],) + vc += (param[vs:ve, :],) + + q = torch.cat(qc) + k = torch.cat(kc) + v = torch.cat(vc) + + return q, k, v + + +def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return lit_weights.get("model", lit_weights) + + +def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: + weight_names = {wk.split(".")[-1] for wk in lit_weights} + # LoRA or QLoRA + if any("lora" in wn for wn in weight_names): + raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") + # adapter v2. adapter_bias will only be in adapter_v2 + elif "adapter_bias" in weight_names: + raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") + # adapter. gating_factor is in adapter and adapter_v2 + elif "gating_factor" in weight_names: + raise NotImplementedError("Converting models finetuned with adapter not yet supported.") + + +@torch.inference_mode() +def convert_lit_checkpoint(*, checkpoint_name: str, out_dir: Path, model_name: str) -> None: + config = Config.from_name(model_name) + + if "falcon" in model_name: + copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") + elif config._mlp_class == "LLaMAMLP": + copy_fn = partial(copy_weights_llama, config) + else: + copy_fn = copy_weights_gpt_neox + + # initialize a new empty state dict to hold our new weights + sd = {} + + # checkpoint_name cannot be hardcoded because there exists different outputs such as + # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") + pth_file = out_dir / checkpoint_name + bin_file = pth_file.with_suffix(".bin") + + with incremental_save(bin_file) as saver: + with contextlib.ExitStack() as stack: + lit_weights = stack.enter_context(lazy_load(pth_file)) + lit_weights = maybe_unwrap_state_dict(lit_weights) + check_conversion_supported(lit_weights) + # Incremental save will trigger error + copy_fn(sd, lit_weights, saver=None) + gc.collect() + saver.save(sd) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(convert_lit_checkpoint, as_positional=False) diff --git a/scripts/prepare_redpajama.py b/scripts/prepare_redpajama.py new file mode 100644 index 0000000..2b56726 --- /dev/null +++ b/scripts/prepare_redpajama.py @@ -0,0 +1,166 @@ +import glob +import json +import os +import sys +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Config, Tokenizer + +filenames_sample = [ + "arxiv_sample.jsonl", + "book_sample.jsonl", + "c4_sample.jsonl", + "cc_2019-30_sample.jsonl", + "cc_2020-05_sample.jsonl", + "cc_2021-04_sample.jsonl", + "cc_2022-05_sample.jsonl", + "cc_2023-06_sample.jsonl", + "github_sample.jsonl", + "stackexchange_sample.jsonl", + "wikipedia_sample.jsonl", +] + +filename_sets = { + "arxiv": "arxiv/arxiv*", + "book": "book/book*", + "c4": "c4/c4-train*", + "common_crawl": "common_crawl/*", + "github": "github/filtered*", + "stackexchange": "stackexchange/stackexchange*", + "wikipedia": "wikipedia/wiki*", +} + + +def prepare_sample( + source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for name in filenames_sample: + if match and match not in name: + continue + + filepath = source_path / name + + if not filepath.is_file(): + raise RuntimeError( + f"Input file not found at {filepath}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + prefix, _ = os.path.splitext(name) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=prefix, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + print(f"Processing {name}") + + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare_full( + source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for set_name, pattern in filename_sets.items(): + if match and match not in set_name: + continue + + is_cc = set_name == "common_crawl" + + filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) + + if not filenames: + raise RuntimeError( + f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=set_name, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for name in filenames: + filepath = source_path / name + + print(f"Processing {name}") + + if is_cc: + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + else: + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + destination_path: Path = Path("data/redpajama_sample"), + sample: bool = True, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained.""" + with open(checkpoint_dir / "lit_config.json") as fp: + config = Config(**json.load(fp)) + + prepare_fn = prepare_sample if sample else prepare_full + prepare_fn( + source_path=source_path, + checkpoint_dir=checkpoint_dir, + destination_path=destination_path, + chunk_size=(config.block_size + 1) * 1024, # block size + 1 for causal, 1024 blocks + match=match, + ) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_slimpajama.py b/scripts/prepare_slimpajama.py new file mode 100644 index 0000000..24ec050 --- /dev/null +++ b/scripts/prepare_slimpajama.py @@ -0,0 +1,105 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for SlimPajama +slimpajama_sets = { + "train": "train/chunk*/*", + "validation": "validation/chunk*/*", + "test": "test/chunk*/*", +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {slimpajama_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": + continue # we don't want to include the github data + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + + num_processes = cpu_count() + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_starcoder.py b/scripts/prepare_starcoder.py new file mode 100644 index 0000000..838a29f --- /dev/null +++ b/scripts/prepare_starcoder.py @@ -0,0 +1,100 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +import pandas as pd + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + try: + contents = pd.read_parquet(filepath, engine='pyarrow')['content'] + except: + print(f"Error reading {filepath}!!") + continue + for text in contents: + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + assert split == "train" # starcoder only has train data + filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + num_processes = 32 + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file