diff --git a/docs/moefication/README.md b/docs/moefication/README.md index 0efa2c1..a6988f7 100644 --- a/docs/moefication/README.md +++ b/docs/moefication/README.md @@ -1,224 +1,246 @@ -# LLaMA-Moeficaiton +# MoEfication of LLaMA Model -LLaMA reference code: the **transformers** library, import path `transformers.models.llama` +This documentation provides the procedures to convert a LLaMA model to LLaMA-MoE. -
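-## File Structure -| Folder | Description | -| --------------- | ----------------------------------------------------- | -| llama_data | Data for training the MoE gating networks | -| llama_download | For downloading the LLaMA model | -| llama_moe | Model implementation of the MoE version of LLaMA, modified from the transformers library | -| run_moefication | For converting a plain LLaMA into a LLaMA-MoE implemented with Moefication | +## Procedures -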
+The conversion from LLaMA to LLaMA-MoE consists of two steps: -## Loading and Using LLaMA-MoE +1. **Split.** Create indices sets $S_1,S_2,\dots,S_n$ (Eq. 5 in the technical report) for each FFN layer in LLaMA. The indices sets indicate the intermediate neurons that should be assigned to experts. Save the indices sets to disk. +2. **Convert.** Create a LLaMA-MoE model from an existing LLaMA checkpoint. Reinitialize the LLaMA-MoE experts by selecting the corresponding neurons in the indices sets. Save the initialized LLaMA-MoE model to disk. -Same as for the transformers-based LLaMA; see the example code `test_llama_moe_transformers.py` for details. -
-## Conversion Procedure +### Split -### 1. Download the LLaMA model and convert it to the transformers format +#### Random Split (Neuron-Independent) -> If you have already downloaded the LLaMA model and converted it to the transformers format, you can skip this step. +To randomly split the intermediate neurons in FFNs, you can run: -Edit the `MODEL_SIZE` and `TARGET_FOLDER` variables in the first 3 lines of `./llama_download/download_llama.sh`, then run `bash download_llama.sh` to download the LLaMA model. +```shell +bash ./scripts/moefication/split/run_split_random.sh +``` -```sh -PRESIGNED_URL="https://agi.gpt4.org/llama/LLaMA/*" -MODEL_SIZE="7B,13B" # 30B,65B -TARGET_FOLDER="" # where all files should end up +Remember to change the following variables: -declare -A N_SHARD_DICT +```shell +num_experts="" # number of experts in each MoE layer -N_SHARD_DICT["7B"]="0" -N_SHARD_DICT["13B"]="1" -...... +model_path="" # path to the LLaMA checkpoint +save_path="" # path to save the indices sets ``` -Run `./llama_download/convert_llama_weights_to_hf.py` with the `input_dir`, `model_size`, and `output_dir` arguments to convert the original LLaMA model into a format that the transformers library can load. -```sh -python convert_llama_weights_to_hf.py --input_dir --model_size 7B --output_dir + +#### Clustering Split (Neuron-Independent) + +To split the intermediate neurons in FFNs by k-means clustering, you can run: + +```shell +bash ./scripts/moefication/split/run_split_clustering.sh ``` -If the command above fails, try setting the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION` environment variable to `python` at run time, i.e. run the following command: +Remember to change the following variables: + +```shell +num_experts="" # number of experts in each MoE layer -```sh -PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python convert_llama_weights_to_hf.py --input_dir --model_size 7B --output_dir +model_path="" # path to the LLaMA checkpoint +save_path="" # path to save the indices sets + +metric="" # metric for clustering, choices: `l2` `cos` +proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj` ``` -
-### 2. Convert to LLaMA-MoE -> For the commands, you can adapt the `sh` files under `./run_moefication/`, or invoke the `python` files directly from the command line. +#### Co-activation Graph Split (Neuron-Independent) + +> This part is not included in our technical report. + +We also implemented the co-activation-graph-based method from [MoEfication](https://arxiv.org/abs/2110.01786) here. + +You need to install [METIS](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) first. Then you can run the following script to perform splitting: + +```shell +bash ./scripts/moefication/split/run_split_graph.sh +``` -Moefication consists of four main stages. +Remember to change the following variables: -1. Split the parameters of the dense layers in the model into n mutually independent subsets (i.e. n experts), **and save the expert index of each neuron**; -2. Run LLaMA inference on the pretraining data and **save the inputs and outputs of LLaMA's intermediate layers** as training data for the gating networks; -3. Train the gating networks that select experts, **and save the gating network parameters**; -4. **Using the saved indices and gating network parameters**, convert the LLaMA model into its MoE form. +```shell +num_experts="" # number of experts in each MoE layer -
+model_path="" # path to the LLaMA checkpoint +save_path="" # path to save the indices sets -#### 2.1 Split the experts and save the expert index of each neuron (split) +metric="" # metric to measure the sparsity, choices: `l1_norm` `l2_norm` `plain` +proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj` +``` -> The original Moefication paper provides two parameter-splitting methods, ***clustering*** and ***graph***, which perform about the same. -> -> Here we use the more general ***clustering*** method. -Run `./run_moefication/llama_split_clustering.py` with the following arguments: -| Argument | Description | -| ----------- | ------------------------------------------------------------ | -| model_path | Path to the LLaMA model in transformers format | -| save_path | Path for saving the index files produced by the expert split; a folder named like "xxx-xxxExpert-Split-Clustering" is created automatically under this location | -| templates | Template of the network parameters on which the expert split is based; defaults to "layers.{}.mlp.gate_proj.weight" | -| num_experts | Number of experts to split into | +#### Gradient Split -An example command: +Before performing gradient-based splitting (Eq. 8 in the technical report), you need to prepare some pretraining data and group it into clusters by running: -```sh -python llama_split_clustering.py --model_path /llama_7B --save_path ../llama_moe_temp_files --num_experts 8 +```shell +python smoe/entrypoint/text_clustering.py ``` -After the command finishes, several files are generated under `../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering`, storing the expert index of each neuron in the dense layers. They are named as follows: +Then, you need to run the following script to get the importance vector $v$ for the intermediate neurons in each layer: -```sh -layers.0.mlp.gate_proj.weight -layers.1.mlp.gate_proj.weight -layers.2.mlp.gate_proj.weight -...... +```shell +bash scripts/moefication/split/run_split_gradient_get_grads.sh ``` -
+Remember to change the following variables: -#### 2.2 Get LLaMA's intermediate-layer features (get_hidden_features) +```shell +dataset_dir="" # path to clustered data +pretrained_model="" # path to the LLaMA checkpoint +tokenizer_path="" # path to the LLaMA tokenizer +save_path="" # path to save the indices sets -##### 2.2.1 Prepare the pretraining data +accumulate_level="" # should be set to `sample` +kernel="" # should be set to `l1_norm` +importance_type="" # should be set to `feature_change` +``` -Training the MoE gating networks requires LLaMA's pretraining data. The figure below shows the sampling proportion of each dataset during LLaMA pretraining. -image-20230718174945369 -Subsets of the datasets above are provided under `./llama_data`. Unzip the `zip` files before use to obtain the datasets stored in `jsonl` format. +##### Neuron Independent -The provided datasets contain roughly the same number of tokens each; the actual proportions are implemented in `./run_moefication/llama_get_hidden_features.py`. +> This part is not included in our technical report. -##### 2.2.2 Run inference with LLaMA +You can also split the intermediate neurons in a neuron-independent manner by treating the expert split as a task assignment problem. To perform the split, you can run: + +```shell +bash ./scripts/moefication/split/run_split_gradient.sh +``` -Run `./run_moefication/llama_select_mlp.py` with the following arguments: +Remember to change the following variables: -| Argument | Description | -| --------------------- | ------------------------------------------------------------ | -| model_path | Path to the LLaMA model in transformers format | -| train_data_path | Path to the LLaMA pretraining data (**Step 2.2.1**) | -| train_data_cache_path | Cache path for the LLaMA pretraining data, used to save and load the data cache and reduce data processing time | -| save_path | Location for saving the gating network parameter files; a path named like "xxx-Hidden-Features" is created automatically under this location | -| templates | Template of the network parameters on which the expert split is based; defaults to "layers.{}.mlp.gate_proj.weight" | -| data_use_percent | Fraction of the data in all datasets to use, for adjusting the amount of training data. 3.5M tokens on LLaMA 7B produce intermediate-layer output files totaling about 4.5T | -| batch_size | Batch size per inference step; adjust according to GPU memory | -| save_interval | Number of batches between saves; a larger value affects GPU memory usage but reduces the number of saved files | +```shell +expert_num="" # number of experts in each MoE layer +expert_size="" # intermediate neurons in each expert +share_neurons="False" ######### SET AS FALSE TO BE NEURON-INDEPENDENT ######### -The inference code uses torch's native multi-GPU mode and must be launched with `python -m torch.distributed.launch`. An example command: +model_path="" # path to the LLaMA checkpoint +score_file_path="" # path to the score files generated above +save_path="" # path to save the indices sets +visualization_path="" # path to save the visualization results -```sh -python -m torch.distributed.launch --nproc_per_node=8 llama_get_hidden_features.py --model_path /llama_7B --train_data_path ../llama_data --train_data_cache_path ../llama_data_cache --save_path ../llama_moe_temp_files --data_use_percent 0.01 --save_interval 1 --batch_size 4 +criterion="" # criterion to judge the importance of neurons, should be set to `max` +proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj` ``` -After the command finishes, several files are generated under `../llama_moe_temp_files/llama_7B-Hidden-Features`, storing the intermediate-layer inputs and outputs. They are named as follows: -```sh --- hidden_inputs - 0_0.pth - 0_1.pth - 0_2.pth - ...... --- hidden_gate_outputs - 0_0.pth - 0_1.pth - 0_2.pth - ...... + +##### Inner-Sharing + +Here we use the same entry script as the **Neuron Independent** strategy above for gradient split. + +```shell +bash ./scripts/moefication/split/run_split_gradient.sh ``` -
+Remember to change the following variables: -#### 2.3 Train the expert selection networks and save the gating network parameters (select) +```shell +expert_num="" # number of experts in each MoE layer +expert_size="" # intermediate neurons in each expert +share_neurons="True" ######### SET AS TRUE TO BE INNER-SHARING ######### -> The original Moefication paper provides two gating mechanisms, ***Similarity*** and ***MLP***. -> -> Here we use a regular MLP as the gating network; its outputs are the selection scores of the experts. This option works better. +model_path="" # path to the LLaMA checkpoint +score_file_path="" # path to the score files generated above +save_path="" # path to save the indices sets +visualization_path="" # path to save the visualization results + +criterion="" # criterion to judge the importance of neurons, should be set to `max` +proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj` +``` -Run `./run_moefication/llama_select_mlp.py` with the following arguments: -| Argument | Description | -| -------------------- | ------------------------------------------------------------ | -| model_path | Path to the LLaMA model in transformers format | -| split_file_path | Path to the expert indices saved in **Step 2.1** | -| hidden_features_path | Path to the LLaMA intermediate-layer features saved in **Step 2.2** | -| save_path | Location for saving the gating network parameter files; a path named like "xxx-xxxExpert-Select-MLP" is created automatically under this location | -| templates | Template of the network parameters on which the expert split is based; defaults to "layers.{}.mlp.gate_proj.weight" | -| select_criterion | Criterion for the expert split; one of plain, positive, or l2_norm; defaults to l2_norm | -| num_experts | Number of experts to split into | -| num_selects | Number of experts to select | -| specify_layer | Layers to process, for running the split in parallel; **leave empty to process all layers**. For example, LLaMA 7B has 32 layers, so you can use 4 commands specifying 0-7, 8-15, 16-23, and 24-31 to train in parallel | -| use_softmax | If set, the MoE gate outputs are activated with Softmax; recommended | -An example command: +##### Inter-Sharing (Residual MoE) -```sh -python llama_split_clustering.py --model_path /llama_7B --split_file_path ../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering --hidden_features_path ../llama_moe_temp_files/llama_7B-Hidden-Features --save_path ../llama_moe_temp_files --select_criterion l2_norm --num_experts 8 --num_selects 2 --use_softmax +You can run the following script to perform inter-sharing split: + +```shell +bash ./scripts/moefication/split/run_split_gradient_residual.sh ``` -After the command finishes, several files are generated under `../llama_moe_temp_files/llama_7B-8Expert-Select-MLP`, storing the parameters of each gating network. They are named as follows: +Remember to change the following variables: + +```shell +expert_num_moe="" # number of non-residual experts +expert_num_residual="" # number of residual experts +expert_size="" # intermediate neurons in each expert +share_neurons="" # whether to share neurons in non-residual experts + +model_path="" # path to the LLaMA checkpoint +score_file_path="" # path to the score files generated above +save_path="" # path to save the indices sets +visualization_path="" # path to save the visualization results -```sh -layers.0.mlp.gate_proj.weight -layers.1.mlp.gate_proj.weight -layers.2.mlp.gate_proj.weight -...... +criterion="" # criterion to judge the importance of neurons, should be set to `max` +proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj` ``` -
-#### 2.4 Convert the LLaMA model and save it (convert) -Change the path on line 4 of `./run_moefication/llama_convert.py` to the project root. +### Convert + +#### Convert LLaMA-MoE from Neuron-Independent Methods + +Run the following script: -```python -sys.path.append("") # change to the project root, e.g. "/home/dongdz/workspace/llama-moefication/" +```shell +bash ./scripts/moefication/convert/run_convert.sh ``` -Run `./run_moefication/llama_convert.py` with the following arguments: -| Argument | Description | -| ---------------- | ------------------------------------------------------------ | -| model_path | Path to the LLaMA model in transformers format | -| split_file_path | Path to the expert indices saved in **Step 2.1** | -| select_file_path | Path to the gating network parameters saved in **Step 2.3** | -| save_path | Path for saving the converted LLaMA-MoE model in transformers format | -| templates | Template of the network parameters on which the expert split is based; defaults to "layers.{}.mlp.gate_proj.weight" | -| num_experts | Number of experts to split into | -| num_selects | Number of experts to select | -| convert_type | Model class to convert to; options: LlamaMoEModel, LlamaMoEForCausalLM, LlamaMoEForSequenceClassification | -An example command: +#### Convert LLaMA-MoE from Inner-Sharing Methods -```sh -python llama_convert.py --model_path /llama_7B --split_file_path ../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering --select_file_path ../llama_moe_temp_files/llama_7B-8Expert-Select-MLP --save_path /llama_7B-MoE --num_experts 8 --num_selects 2 --convert_type LlamaMoEForCausalLM +Run the following script: + +```shell +bash ./scripts/moefication/convert/run_convert_gradient.sh ``` -After the command finishes, the LLaMA-MoE model is saved under `/7B-MoE`. -
-## Code Notes: MoE Implementation Logic +#### Convert LLaMA-MoE from Inter-Sharing Methods (Residual MoE) + +Run the following script: + +```shell +bash ./scripts/moefication/convert/run_convert_gradient_residual.sh +``` + + + +## File Structure + +``` +--smoe + -- scripts + -- moefication + -- convert + -- get_hidden_features (deprecated) + -- prune (deprecated) + -- select (deprecated) + -- split + -- smoe + -- entrypoint + -- moefication +``` + + + + -See the code under `./llama_moe/moe_utils/` for details; `moe_layers.py` provides a ready-to-use MoE layer that can be called directly. diff --git a/docs/moefication/readme-image.png b/docs/moefication/readme-image.png deleted file mode 100644 index 8917c2d..0000000 Binary files a/docs/moefication/readme-image.png and /dev/null differ diff --git a/scripts/moefication/convert/run_convert_gradient_residual.sh b/scripts/moefication/convert/run_convert_gradient_residual.sh index a516d41..0a21cc1 100644 --- a/scripts/moefication/convert/run_convert_gradient_residual.sh +++ b/scripts/moefication/convert/run_convert_gradient_residual.sh @@ -13,7 +13,7 @@ expert_size=1376 # 688 1376 2752 5504 11008 # 864 1728 3456 6912 13824 -score_scale_factor_residual=4.0 # 4.0 8.0 12.0 16.0 +score_scale_factor_residual=1.0 # 4.0 8.0 12.0 16.0 score_scale_factor=4.0 # 4.0 8.0 12.0 16.0 convert_type=LlamaMoEResidualForCausalLM # LlamaMoEResidualModel LlamaMoEResidualForCausalLM LlamaMoEResidualForSequenceClassification diff --git a/scripts/moefication/split/run_split_graph.py b/scripts/moefication/split/run_split_graph.py deleted file mode 100644 index 96705bb..0000000 --- a/scripts/moefication/split/run_split_graph.py +++ /dev/null @@ -1,68 +0,0 @@ -import subprocess - -# Define the bash commands - - -bash_commands = """ -# llama_7B llama_13B llama_30B llama_base -# llama2_7B llama2_13B llama2_30B llama2_base -llama_size=llama_13B - -num_experts=16 # 8 16 -metric=l1_norm # l1_norm l2_norm plain -template=layers.{}.mlp.up_proj.weight # gate_proj up_proj -threshold=1 - -data_path=/mnt/petrelfs/share_data/quxiaoye 
-model_path=${data_path}/models/${llama_size} -save_path=${data_path}/moefication_results/split/${llama_size}-${num_experts}Expert-Split-Graph-${metric}/ -hidden_features_path=${data_path}/moefication_results/features/${llama_size}-Hidden-Features - -gpus=0 -cpus=16 - -# STEP1 - -for specify_layer in {0..39}; do - OMP_NUM_THREADS=2 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \ - python -m smoe.entrypoint.moefication.llama_split_graph \ - --model_path ${model_path} \ - --save_path ${save_path} \ - --specify_layer ${specify_layer} \ - --template ${template} \ - --num_experts ${num_experts} \ - --threshold ${threshold} \ - --metric ${metric} \ - --hidden_features_path ${hidden_features_path} & - sleep 0.7 -done -wait - -# STEP2 - -gpmetis_run=/mnt/petrelfs/share_data/quxiaoye/metis_for_graph_split/bin/gpmetis -template1=layers. -template2=.mlp.up_proj.weight - -for layer in {0..39}; do - OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \ - ${gpmetis_run} ${save_path}/${template1}${layer}${template2} ${num_experts} & - sleep 0.7 -done -wait -# STEP3 - -template3=.part.${num_experts} - -for layer in {0..39}; do - OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \ - python -m smoe.entrypoint.moefication.llama_split_graph_trans_gp \ - --gpmetised_file_path ${save_path}/${template1}${layer}${template2}${template3} & - sleep 0.7 -done -wait -chmod -R 755 ${save_path} >/dev/null 2>&1 -""" - -# Execute the bash commands using Python's subprocess module -subprocess.run(bash_commands, shell=True, executable="/bin/bash") diff --git a/scripts/moefication/split/run_split_graph.sh b/scripts/moefication/split/run_split_graph.sh index 589c762..b892014 100644 --- 
a/scripts/moefication/split/run_split_graph.sh +++ b/scripts/moefication/split/run_split_graph.sh @@ -4,9 +4,9 @@ # llama2_7B llama2_13B llama2_30B llama2_base llama_size=llama_13B -num_experts=8 # 8 16 -metric=l2_norm # l1_norm l2_norm plain -template=layers.{}.mlp.up_proj.weight # gate_proj up_proj +num_experts=16 # 8 16 +metric=l1_norm # l1_norm l2_norm plain +proj_type=up_proj # gate_proj up_proj threshold=1 data_path=/mnt/petrelfs/share_data/quxiaoye @@ -25,7 +25,7 @@ for specify_layer in {0..39}; do --model_path ${model_path} \ --save_path ${save_path} \ --specify_layer ${specify_layer} \ - --template ${template} \ + --template layers.{}.mlp.${proj_type}.weight \ --num_experts ${num_experts} \ --threshold ${threshold} \ --metric ${metric} \ @@ -38,7 +38,7 @@ wait gpmetis_run=/mnt/petrelfs/share_data/quxiaoye/metis_for_graph_split/bin/gpmetis template1=layers. -template2=.mlp.up_proj.weight +template2=.mlp.${proj_type}.weight for layer in {0..39}; do OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \ @@ -46,6 +46,7 @@ for layer in {0..39}; do sleep 0.7 done wait + # STEP3 template3=.part.${num_experts} @@ -57,4 +58,5 @@ for layer in {0..39}; do sleep 0.7 done wait + chmod -R 755 ${save_path} >/dev/null 2>&1 diff --git a/scripts/moefication/split/run_split_random.sh b/scripts/moefication/split/run_split_random.sh index da2485b..aecdc8c 100644 --- a/scripts/moefication/split/run_split_random.sh +++ b/scripts/moefication/split/run_split_random.sh @@ -5,8 +5,7 @@ # open_llama_7b llama_size="open_llama_7b" -num_experts=8 # 8 16 -proj_type=gate_proj # gate_proj up_proj +num_experts=8 # 8 16 data_path=/mnt/petrelfs/share_data/quxiaoye model_path=${data_path}/models/${llama_size} @@ -18,7 +17,7 @@ OMP_NUM_THREADS=2 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${ python -m smoe.entrypoint.moefication.llama_split_random \ --model_path 
${model_path} \ --save_path ${save_path} \ - --template layers.{}.mlp.${proj_type}.weight \ + --template layers.{}.mlp.up_proj.weight \ --num_experts ${num_experts} chmod -R 755 ${save_path} >/dev/null 2>&1 diff --git a/smoe/entrypoint/moefication/llama_split_random.py b/smoe/entrypoint/moefication/llama_split_random.py index b9c4146..eb26dad 100644 --- a/smoe/entrypoint/moefication/llama_split_random.py +++ b/smoe/entrypoint/moefication/llama_split_random.py @@ -11,7 +11,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--model_path', type=str, default="/home/data/models/llama-transformers/7B") parser.add_argument('--save_path', type=str, default="/home/dongdz/workspace/moefication/llama_moe_temp_files/") - parser.add_argument('--template', type=str, default='layers.{}.mlp.gate_proj.weight') + parser.add_argument('--template', type=str, default='layers.{}.mlp.up_proj.weight') parser.add_argument('--num_experts', type=int, default=8, help='number of experts') args = parser.parse_args()
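
Note on the random split touched by the last two hunks: the idea can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code in `smoe/entrypoint/moefication/llama_split_random.py`. The function name `random_split`, its arguments, and the `seed` parameter are hypothetical, and the sketch assumes the intermediate size divides evenly by the number of experts (for example, 11008 neurons into 8 experts of 1376 neurons each, matching the `expert_size` values listed in the scripts).

```python
import random


def random_split(intermediate_size, num_experts, seed=0):
    """Randomly partition neuron indices into equally sized, disjoint index sets.

    Hypothetical helper for illustration only; the repository's own
    implementation may differ in naming, file output, and options.
    """
    if intermediate_size % num_experts != 0:
        raise ValueError("intermediate_size must be divisible by num_experts")

    # Shuffle all intermediate-neuron indices with a fixed seed for reproducibility.
    indices = list(range(intermediate_size))
    random.Random(seed).shuffle(indices)

    # Cut the shuffled list into num_experts contiguous chunks, one per expert.
    expert_size = intermediate_size // num_experts
    return [
        sorted(indices[i * expert_size:(i + 1) * expert_size])
        for i in range(num_experts)
    ]


if __name__ == "__main__":
    index_sets = random_split(11008, 8)
    print(len(index_sets), len(index_sets[0]))
```

Because the split only shuffles indices and never reads the weights, the `--template` value (now fixed to `up_proj`) would matter only for naming the output files in a scheme like this one; that reading is an assumption, not something the diff states.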