
Commit

Merge pull request #42 from pjlab-sys4nlp/data_mix
Update gate load vis, update readme
DaizeDong authored Dec 24, 2023
2 parents 0c37546 + 530af61 commit 8245759
Showing 59 changed files with 9,769 additions and 211 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
docs/imgs/title-favicon.png filter=lfs diff=lfs merge=lfs -text
docs/imgs/MoE-Routing.gif filter=lfs diff=lfs merge=lfs -text
9 changes: 8 additions & 1 deletion .vscode/launch.json
@@ -4,12 +4,19 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "tokenize",
"type": "python",
"request": "launch",
"module": "smoe.utils.tokenize",
"justMyCode": true
},
{
"name": "Python: Remote Attach",
"type": "python",
"request": "attach",
"connect": {
"host": "SH-IDCA1404-10-140-54-123",
"host": "x.x.x.x",
"port": 5678
},
"pathMappings": [
149 changes: 131 additions & 18 deletions README.md
@@ -1,33 +1,146 @@
# train-moe
<div align="center">
<h1>LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training</h1>
<img src="docs/imgs/title-favicon.png" width="200" alt="LLaMA-MoE favicon" style="border-radius: 5%;"><br />
<span style="color:red">📢 <strong><i>A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!!</i></strong></span>
<div>
<a href="https://huggingface.co/llama-moe" target="_blank">🤗 Model Weights</a> | <a href="#" target="_blank">📃 Technical Report</a> | <a href="#quick-start">🚀 Quick Start</a><br />
<a href="docs/Installation.md">⚙️ Installation Guide</a> | <a href="#expert-construction">🚧 Expert Construction</a> | <a href="#continual-pretraining">🚅 Continual Pre-training</a> | <a href="#evaluation">💎 Evaluation</a>
</div>
</div>

[[Installation Guide]](docs/Installation.md) | [[MoEfication Docs]](docs/moefication/README.md) | [[Continual Pre-training Docs]](docs/continual_pretraining/README.md)
<h2 id="llama-moe">🎉 Introduction</h2>

## 🌴 Dependencies
LLaMA-MoE is a series of open-source Mixture-of-Experts (MoE) models based on [LLaMA](https://github.com/facebookresearch/llama) and [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama).
We build LLaMA-MoE with the following two steps:
1. Partition LLaMA's FFNs into sparse experts and insert a top-K gate for each layer of experts.
2. Continually pre-train the initialized MoE model with optimized data sampling weights from [Sheared LLaMA](https://arxiv.org/abs/2310.06694) and filtered datasets from [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama).

- Python==3.11.4
- Packages: please check `requirements.txt` (NOTE: `flash-attn` must be properly installed by following [their instructions](https://github.com/Dao-AILab/flash-attention))
![MoE Routing](./docs/imgs/MoE-Routing.gif)

<h2 id="features">🔥 Features</h2>

1. **Lightweight Models**: The total number of model parameters is only 6.7B, which makes the models friendly for deployment and research use.
2. **Multiple Expert Construction Methods**:
1. Neuron-Independent: Random, Clustering, Co-activation Graph, Gradient ([Zhang et al., 2022](http://arxiv.org/abs/2110.01786), [Zuo et al., 2022](http://arxiv.org/abs/2204.07675))
2. Neuron-Sharing: Inner, Inter (residual)
3. **Multiple MoE Gating Strategies** (an illustrative routing sketch follows this list):
1. TopK Noisy Gate ([Shazeer et al., 2017](http://arxiv.org/abs/1701.06538))
2. Switch Gating ([Fedus et al., 2022](http://arxiv.org/abs/2101.03961))
4. **Fast Continual Pre-training**:
1. FlashAttention-v2 integrated ([Dao, 2023](https://github.com/Dao-AILab/flash-attention))
2. Fast streaming dataset loading
5. **Abundant Monitor Items**:
1. Gate load, gate importance
2. Loss on steps, loss on tokens, balance loss
3. TGS (tokens/GPU/second), MFU (model FLOPs utilization)
4. Other visualization utilities
6. **Dynamic Weight Sampling**:
1. Self-defined static sampling weights
2. Sheared LLaMA's dynamic batch loading ([Xia et al., 2023](http://arxiv.org/abs/2310.06694))
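
For readers new to MoE routing, the following snippet sketches what a top-K noisy gate computes. It is a minimal, self-contained illustration, not the gate implementation shipped in this repository; the class and variable names are invented for the example, and the balance-loss bookkeeping listed under the monitor items is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKNoisyGate(nn.Module):
    """Minimal top-K noisy gate in the spirit of Shazeer et al. (2017).

    Illustrative only: NOT the gate implementation used in this repository.
    """

    def __init__(self, hidden_size: int, num_experts: int, num_selects: int):
        super().__init__()
        self.num_selects = num_selects
        self.gate_proj = nn.Linear(hidden_size, num_experts, bias=False)
        self.noise_proj = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (num_tokens, hidden_size)
        logits = self.gate_proj(hidden_states)
        if self.training:
            # Input-dependent Gaussian noise encourages exploration across experts.
            noise_std = F.softplus(self.noise_proj(hidden_states))
            logits = logits + torch.randn_like(logits) * noise_std
        # Keep the K highest-scoring experts per token and renormalize their weights.
        top_logits, top_indices = logits.topk(self.num_selects, dim=-1)
        top_weights = F.softmax(top_logits, dim=-1)
        return top_indices, top_weights  # which experts to run, and their mixing weights


gate = TopKNoisyGate(hidden_size=4096, num_experts=16, num_selects=4)
indices, weights = gate(torch.randn(8, 4096))
print(indices.shape, weights.shape)  # torch.Size([8, 4]) torch.Size([8, 4])
```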


<h2 id="quick-start">🚀 QuickStart</h2>

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True)
model.eval()
model.to("cuda:0")

input_text = "Suzhou is famous of"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = inputs.to("cuda:0")

pred = model.generate(**inputs, max_length=50, temperature=0.0)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three
```

<h2 id="performance">📊 Model Performance</h2>

| Model | \#Activated Experts | \#Experts | \#Activated Params | Links |
| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: |
| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) |
| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) |
| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) |

| Model | Average | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMLU (5) |
| :------------------------------------------------------------------------------------ | :------: | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: |
| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 50.3 | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 |
| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 51.5 | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 |
| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 53.7 | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 |
| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 55.6 | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 |
| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 56.4 | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** |
| **LLaMA-MoE-3.0B** | 55.5 | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 |
| **LLaMA-MoE-3.5B (4/16)** | **57.7** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 |
| **LLaMA-MoE-3.5B (2/8)** | 57.6 | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 |


<h2 id="expert-construction">🚧 Expert Construction</h2>

- Neuron-Independent
- Independent<sub>Random</sub>: `bash ./scripts/moefication/split/run_split_random.sh`
- Independent<sub>Clustering</sub>: `bash ./scripts/moefication/split/run_split_clustering.sh`
- Neuron-Sharing
- Sharing<sub>Inner</sub>: `bash ./scripts/moefication/split/run_split_gradient.sh`
- Sharing<sub>Inter</sub>: `bash ./scripts/moefication/split/run_split_gradient_residual.sh`

For more information, please refer to [Expert Construction docs](docs/moefication/README.md).

<h2 id="continual-pretraining">🚅 Continual Pre-training</h2>

## 🚀 QuickStart

### Tokenization

- RedPajama: `bash scripts/tokenize/redpajama.sh` (Don't forget to change the folder paths.)
Download [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama) into `/path_to_data` and put data from different domains into separate folders:
- `/path_to_data/en_arxiv`
- `/path_to_data/en_book`
- `/path_to_data/en_c4`
- `/path_to_data/en_cc`
- `/path_to_data/en_stack`
- `/path_to_data/en_wikipedia`
- `/path_to_data/github`

Each file should end with `*.jsonl`, and each line should look like:
```
{"id": "id-info", "content": "raw text to be tokenized"}
```
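
Optionally, a quick sanity check can confirm that every line in a domain folder parses as JSON and carries the two expected fields. This is a helper sketch only; the folder path below is the same placeholder as above.

```python
import json
from pathlib import Path

# Placeholder path: point this at one of the domain folders prepared above.
domain_dir = Path("/path_to_data/en_arxiv")

for jsonl_file in sorted(domain_dir.glob("*.jsonl")):
    with jsonl_file.open() as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)
            missing = {"id", "content"} - record.keys()
            assert not missing, f"{jsonl_file}:{line_no} is missing fields: {missing}"
print("all lines contain 'id' and 'content'")
```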

Run the following command to tokenize the data in each folder:

```bash
python -m smoe.utils.tokenize \
-f jsonl \
-t /path_to_tokenizer \
-i /path_to_data/en_arxiv \
-o /path_to_data_tokenized/en_arxiv
```
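
To tokenize all seven domain folders in one pass, a thin wrapper can loop over them and re-invoke the command above for each folder. This is a convenience sketch that only reuses the documented flags and the placeholder paths from above.

```python
import subprocess
import sys

# Domain folders as laid out above; adjust the placeholder paths as needed.
domains = ["en_arxiv", "en_book", "en_c4", "en_cc", "en_stack", "en_wikipedia", "github"]

for domain in domains:
    subprocess.run(
        [
            sys.executable, "-m", "smoe.utils.tokenize",
            "-f", "jsonl",
            "-t", "/path_to_tokenizer",
            "-i", f"/path_to_data/{domain}",
            "-o", f"/path_to_data_tokenized/{domain}",
        ],
        check=True,
    )
```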

### Continual Pre-training (CPT)

- **NOTICE:** Please create the `logs/` folder manually: `mkdir -p logs`
- To run the continual pre-training, please check the [CPT docs](docs/continual_pretraining/README.md).

- LLaMA MoEfication LoRA: `sbatch scripts/cpt/lora.sh`
- LLaMA MoEfication Full-Parameter: `sbatch scripts/cpt/fpt.sh`
<h2 id="evaluation">💎 Evaluation</h2>

## 🤝 Contribution
- For evaluation on Natural Questions (NQ), please refer to [opencompass](https://github.com/Spico197/opencompass/tree/main).
- For other tasks, please refer to [lm-eval-harness](https://github.com/spico197/smoe-eval).

- Make sure the Python version is `>=3.10` (a strict version constraint for better type hinting)
<h2 id="citation">📑 Citation</h2>

```bash
$ conda install git # upgrade git
$ git clone [email protected]:pjlab-sys4nlp/train-moe.git
$ cd train-moe
$ pip install -e .[dev]
$ pre-commit install
```
```bibtex
@article{llama-moe-2023,
title={LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training},
author={LLaMA-MoE Team},
journal={arXiv preprint arXiv:},
url={https://arxiv.org/abs/},
year={2023}
}
```

<hr>
<p align="center">LLaMA-MoE Team w/ ❤️</p>
11 changes: 11 additions & 0 deletions docs/Contribution.md
@@ -0,0 +1,11 @@
# 🤝 Contribution

- Make sure the Python version is `>=3.10` (a strict version constraint for better type hinting)

```bash
$ conda install git # upgrade git
$ git clone [email protected]:pjlab-sys4nlp/llama-moe.git
$ cd llama-moe
$ pip install -e .[dev]
$ pre-commit install
```
4 changes: 2 additions & 2 deletions docs/Installation.md
@@ -1,7 +1,7 @@
# 🌴 Installation

1. Prepare conda environment: `conda create -n smoe python=3.11` (If your environment name is not `smoe`, you may need to change the environment name in the launching scripts)
2. Add environment variables in `~/.bashrc` (`gcc` is set to newer version for installing `flash-attn`):
2. Add the correct environment variables in `~/.bashrc` (`gcc` is set to a newer version for installing `flash-attn`), e.g.:
```bash
export PATH=/mnt/petrelfs/share/cuda-11.8/bin:$PATH
export LD_LIBRARY_PATH=/mnt/petrelfs/share/cuda-11.8/lib64:$LD_LIBRARY_PATH
@@ -11,7 +11,7 @@
3. Take the variables into effect: `source ~/.bashrc`
4. Install PyTorch (CUDA-11.8): `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`
5. Install dependencies: `pip install -r requirements.txt`
6. Install `flash-attn`: `pip install flash-attn==2.0.1 --no-build-isolation`
6. Install `flash-attn`: `pip install flash-attn==2.0.1 --no-build-isolation`. You may need to follow the [flash-attn installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) to avoid build errors.
7. Install the latest Git: `conda install git`
8. Clone the repo: `git clone [email protected]:pjlab-sys4nlp/train-moe.git` (If you haven't set up an SSH key for GitHub, you may not be able to clone over SSH. Check the [docs](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account) for instructions.)
9. Change current directory: `cd train-moe`
59 changes: 35 additions & 24 deletions docs/continual_pretraining/README.md
@@ -1,6 +1,38 @@
# 🚅 Training Guide

## ⚙️ Configuration Instructions
## 🗞️ Executive Scripts

| Description | Path |
| :------------------------ | :------------------------------------------------------------------------------------- |
| LLaMA-MoE 2/16 Experts | `scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh` |
| LLaMA-MoE 4/16 Experts | `scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh` |
| Dynamic<sub>Sheared</sub> | `scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh` |

## 🌴 Other Arguments in Executive Scripts

| Argument Name | Description |
| :------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--dynamic_data_selection` | For different dynamic data sampling strategies, choose one from: `sheared_llama` or `none` (static). Default: `none` |
| `--moe_calculator_score_scale_factor` | Scale factor applied to the hidden states after they are processed by the experts. Should be $\frac{\text{\#total experts}}{\text{\#selected}}$ (see the example after this table). Default: `4.0` |
| `--num_selects` | The number of selected experts. Default: `4` |
| `--gate_balance_loss_weight` | The weight of the balance loss for the gate. Default: `1e-2` |
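
As a concrete instance of the scale-factor rule: for the 2/16 setting there are 16 total experts and 2 selected per token, so the factor is 16 / 2 = 8 (presumably the `sf8` suffix in the 2/16 script listed earlier). The snippet below only illustrates the arithmetic and echoes flags already documented in this table.

```python
# Illustrative arithmetic only; the echoed flags are the ones documented above.
num_experts, num_selects = 16, 2          # the 2/16 LLaMA-MoE configuration
scale_factor = num_experts / num_selects  # 16 / 2 = 8.0
print(
    f"--num_selects {num_selects} "
    f"--moe_calculator_score_scale_factor {scale_factor} "
    f"--gate_balance_loss_weight 1e-2"
)
```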

## 📋 Checklist before Starting an Experiment

- [ ] balance loss weight
- [ ] scale factor
- [ ] learning rate
- [ ] warmup steps
- [ ] evaluation steps
- [ ] logging steps
- [ ] global batch size
- [ ] number of selected experts
- [ ] pretrained model
- [ ] data path
- [ ] GPUs
- [ ] comment

## ⚙️ Configuration Instructions for Slurm Users

For the `scripts/cpt/lora.sh` and `scripts/cpt/fpt.sh` files, you can launch an experiment via `sbatch`, e.g. `sbatch scripts/cpt/lora.sh`.

@@ -36,12 +68,6 @@ llama1-7b 16 select 4: 3.49b params

llama1-13b total params: 13,015,864,320 - total mlp params: 8,493,465,600

| total experts | selected | dropped params | added gate params | total params |
| ------------: | -------: | -------------: | ----------------: | ------------: |
| 16 | 8 | 4,246,732,800 | 3,287,040 | 8,772,418,560 |
| 16 | 4 | 6,370,099,200 | 3,287,040 | 6,649,052,160 |
| 16 | 2 | 7,431,782,400 | 3,287,040 | 5,587,368,960 |

## 🧮 Estimation of Training Speed and Tokens

For convenient estimation of the model training speed, we provide some useful information at the very beginning of log files:
@@ -64,7 +90,7 @@

The TensorBoard `logging_dir` can be found at `outputs/<job-name>-<job-id>/runs/<logging-dir>`.

For example, if my job name is `cpt-moe-fpt-bs16-48gpus` in the sbatch file, the tensorboard could be started from that by: `tensorboard --logdir outputs/cpt-moe-fpt-bs16-48gpus-1535835/runs/Jul31_14-12-00_SH-IDCA1404-10-140-54-100` .
For example, if the job name in the sbatch file is `cpt-moe-fpt-bs16-48gpus`, TensorBoard can be started via: `tensorboard --logdir outputs/cpt-moe-fpt-bs16-48gpus-1535835/runs/Jul31_14-12-00`.

For multiple tasks with different logging directories, you could run the following command:

@@ -75,20 +101,5 @@ $ tensorboard --logdir_spec short_name:dir1,short_name2:dir2 --port 8001
Here, the `short_name` is an abbreviation for your task, and the port number could be changed manually if there's a port conflict. e.g.

```bash
$ tensorboard --logdir_spec moe_from_scratch:outputs/cpt-llama-moe-scratch-lora-bs16-1476932/runs/Jul26_21-53-42_SH-IDCA1404-10-140-54-121,moe_lora:outputs/cpt-llama-lora-bs16-1476918/runs/Jul26_21-31-09_SH-IDCA1404-10-140-54-122 --port 8001
$ tensorboard --logdir_spec moe_from_scratch:outputs/cpt-llama-moe-scratch-lora-bs16-1476932/runs/Jul26_21-53-42,moe_lora:outputs/cpt-llama-lora-bs16-1476918/runs/Jul26_21-31-09 --port 8001
```

## 📋 Checklist before Starting an Experiment

- [ ] balance loss weight
- [ ] scale factor
- [ ] learning rate
- [ ] warmup steps
- [ ] evaluation steps
- [ ] logging steps
- [ ] global batch size
- [ ] number of selected experts
- [ ] pretrained model
- [ ] data path
- [ ] GPUs
- [ ] comment
3 changes: 3 additions & 0 deletions docs/imgs/MoE-Routing.gif
Binary file not shown.
3 changes: 3 additions & 0 deletions docs/imgs/title-favicon.png
Binary file not shown.
17 changes: 17 additions & 0 deletions example.py
@@ -0,0 +1,17 @@
import torch
from transformers import AutoTokenizer

from smoe.models.llama_moe import LlamaMoEForCausalLM

model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2342244/checkpoint-13600/"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = LlamaMoEForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
model.to("cuda:0")

input_text = "Suzhou is famous of"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = inputs.to("cuda:0")

pred = model.generate(**inputs, max_length=50, temperature=0.0)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three
2 changes: 2 additions & 0 deletions requirements.txt
@@ -37,3 +37,5 @@ Pillow==9.4.0
numpy==1.25.0
opencv-python==4.8.1.78
pynvml==11.5.0
PyYaml==6.0.1
pandas<2.1.0