diff --git a/docs/moefication/README.md b/docs/moefication/README.md
index 0efa2c1..a6988f7 100644
--- a/docs/moefication/README.md
+++ b/docs/moefication/README.md
@@ -1,224 +1,246 @@
-# LLaMA-Moeficaiton
+# MoEfication of LLaMA Model
-LLaMA参考代码:**transformers**库,引用路径`transformers.models.llama`
+This document describes the procedure for converting a LLaMA model into a LLaMA-MoE model.
-
-## 文件结构
-| 文件夹 | 说明 |
-| --------------- | ----------------------------------------------------- |
-| llama_data | 用于训练MoE门控网络的数据 |
-| llama_download | 用于下载LLaMA模型 |
-| llama_moe | MoE版LLaMA的模型实现,基于transformers库修改 |
-| run_moefication | 用于将普通的LLaMA转化为使用Moefication实现的LLaMA-MoE |
+## Procedures
-
+The conversion from LLaMA to LLaMA-MoE consists of two steps:
-## LLaMA-MoE的加载与使用
+1. **Split.** Create indices sets $S_1, S_2, \dots, S_n$ (Eq. 5 in the technical report) for each FFN layer in LLaMA. The indices sets indicate which intermediate neurons should be assigned to which expert. Save the indices sets to disk.
+2. **Convert.** Create a LLaMA-MoE model from an existing LLaMA checkpoint. Reinitialize the LLaMA-MoE experts by selecting the corresponding neurons according to the indices sets. Save the initialized LLaMA-MoE model to disk.
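+
+The sketch below illustrates the two steps with the random split as an example (a minimal, hedged sketch; the tensor shapes and the file name are illustrative, not the actual `smoe` API):
+
+```python
+# Minimal sketch of split-then-convert with a random split (illustrative only).
+import torch
+
+num_experts = 8
+intermediate_size = 11008  # FFN intermediate size of LLaMA-7B
+hidden_size = 4096
+
+# Step 1 (split): partition neuron indices into S_1, ..., S_n and save them.
+perm = torch.randperm(intermediate_size)
+indices_sets = perm.chunk(num_experts)  # each chunk = one expert's neuron set
+torch.save([s.tolist() for s in indices_sets], "indices_sets.pt")  # hypothetical file name
+
+# Step 2 (convert): reinitialize each expert by selecting its neurons from
+# the dense FFN weights (rows of gate/up_proj; down_proj would take columns).
+gate_proj = torch.empty(intermediate_size, hidden_size)  # stand-in for a checkpoint tensor
+experts = [gate_proj[s] for s in indices_sets]
+```
+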
-与基于transformers库的LLaMA相同,详情请见示例代码`test_llama_moe_transformers.py`。
-
-## 转换流程
+### Split
-### 1. 下载LLaMA模型并转换为transformers库的格式
+#### Random Split (Neuron-Independent)
-> 如果已经下载了LLaMA模型,并进行了transformers库的格式转换,可跳过该步骤。
+To randomly split the intermediate neurons in FFNs, you can run:
-修改`./llama_download/download_llama.sh`文件前3行中的`MODEL_SIZE`与`TARGET_FOLDER`变量,之后运行`bash download_llama.sh`下载LLaMA模型。
+```shell
+bash ./scripts/moefication/split/run_split_random.sh
+```
-```sh
-PRESIGNED_URL="https://agi.gpt4.org/llama/LLaMA/*"
-MODEL_SIZE="7B,13B" # 30B,65B
-TARGET_FOLDER="" # where all files should end up
+Remember to change the following variables:
-declare -A N_SHARD_DICT
+```shell
+num_experts="" # number of experts in each MoE layer
-N_SHARD_DICT["7B"]="0"
-N_SHARD_DICT["13B"]="1"
-......
+model_path="" # path to the LLaMA checkpoint
+save_path="" # path to save the indices sets
```
-运行`./llama_download/convert_llama_weights_to_hf.py`文件,并传入`input_dir`,`model_size`,`output_dir`参数,将原版LLaMA模型转换为transformers库可识别的格式。
-```sh
-python convert_llama_weights_to_hf.py --input_dir --model_size 7B --output_dir
+
+#### Clustering Split (Neuron-Independent)
+
+To split the intermediate neurons in FFNs by k-means clustering, you can run:
+
+```shell
+bash ./scripts/moefication/split/run_split_clustering.sh
```
-如果上述命令报错,可尝试在运行时设置`PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION`系统变量为`python`,即运行以下命令:
+Remember to change the following variables:
+
+```shell
+num_experts="" # number of experts in each MoE layer
-```sh
-PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python convert_llama_weights_to_hf.py --input_dir --model_size 7B --output_dir
+model_path="" # path to the LLaMA checkpoint
+save_path="" # path to save the indices sets
+
+metric="" # metric for clustering, choices: `l2` `cos`
+proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj`
```
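+
+For intuition, a simplified clustering split might look like the sketch below (assuming scikit-learn; the real script balances cluster sizes so experts are equally sized, while plain `KMeans` is an unbalanced approximation):
+
+```python
+# Simplified sketch of the clustering split (unbalanced k-means).
+import torch
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import normalize
+
+num_experts = 8
+weight = torch.randn(11008, 4096).numpy()  # stand-in for layers.i.mlp.up_proj.weight
+
+feats = normalize(weight)  # row-normalize for the `cos` metric; use raw rows for `l2`
+labels = KMeans(n_clusters=num_experts, n_init=10).fit_predict(feats)
+
+# S_j = indices of the neurons assigned to cluster j
+indices_sets = [(labels == j).nonzero()[0].tolist() for j in range(num_experts)]
+```
+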
-
-### 2. 进行LLaMA-MoE转化
-> 运行命令可参照目录`./run_moefication/`下的`sh`文件进行修改,也可直接命令行调用`python`文件。
+#### Co-activation Graph Split (Neuron-Independent)
+
+> This part is not included in our technical report.
+
+We also implement the co-activation graph based method from [MoEfication](https://arxiv.org/abs/2110.01786).
+
+You need to install [METIS](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) first. Then you can run the following script to perform the split:
+
+```shell
+bash ./scripts/moefication/split/run_split_graph.sh
+```
-Moefication主要包含四个阶段。
+Remember to change the following variables:
-1. 将模型中的dense层参数进行分割,得到n个相互独立的子集(即n个专家),**并保存各神经元所对应的专家索引**;
-2. 使用LLaMA对预训练数据进行推理,**保存LLaMA的中间层输入输出**,作为门控网络的训练数据。
-2. 训练用于选择专家的门控网络,**并保存门控网络的参数**;
-3. **通过保存的索引与门控网络参数**,将LLaMA模型转换为MoE形式。
+```shell
+num_experts="" # number of experts in each MoE layer
-
+model_path="" # path to the LLaMA checkpoint
+save_path="" # path to save the indices sets
-#### 2.1 进行专家划分,保存各神经元所对应的专家索引 (split)
+metric="" # metric to measure the sparsity, choices: `l1_norm` `l2_norm` `plain`
+proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj`
+```
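+
+The statistic behind this split is neuron co-activation. A toy version of the edge weights might be computed as below (shapes are illustrative; the real pipeline writes a METIS graph file and calls `gpmetis` to partition it):
+
+```python
+# Toy sketch of co-activation edge weights between intermediate neurons.
+import torch
+
+acts = torch.relu(torch.randn(1024, 11008))  # hypothetical FFN activations: (tokens, neurons)
+coact = acts.T @ acts  # (neurons, neurons): how strongly pairs of neurons fire together
+# coact[i, j] becomes the weight of edge (i, j) in the graph fed to gpmetis.
+```
+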
-> Moefication原文提供了***clustering***与***graph***两种参数分割方式,两者效果近似相同。
->
-> 此处使用了更加通用的***clustering***分割方式。
-运行`./run_moefication/llama_split_clustering.py`文件,并传入下述参数:
-| 参数名 | 说明 |
-| ----------- | ------------------------------------------------------------ |
-| model_path | transformers库格式的LLaMA模型的路径 |
-| save_path | 保存专家划分后的index文件的路径,会该位置下自动创建名如"xxx-xxxExpert-Split-Clustering"的文件夹进行保存 |
-| templates | 专家划分所依据的网络参数模板,默认为"layers.{}.mlp.gate_proj.weight" |
-| num_experts | 专家划分的数量 |
+#### Gradient Split
-下面为运行命令的范例:
+Before performing the gradient-based split (Eq. 8 in the technical report), you need to prepare some pretraining data and group it into clusters by running:
-```sh
-python llama_split_clustering.py --model_path /llama_7B --save_path ../llama_moe_temp_files --num_experts 8
+```shell
+python smoe/entrypoint/text_clustering.py
```
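+
+Conceptually, this step groups documents into clusters so that each cluster can later drive its own gradient statistics. A hedged sketch with TF-IDF features and k-means (the actual features and clusterer in `smoe` may differ):
+
+```python
+# Hedged sketch of clustering pretraining documents (illustrative pipeline).
+from sklearn.cluster import KMeans
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+docs = ["def f(x): return x", "the cat sat on the mat", "E = mc^2"]
+features = TfidfVectorizer().fit_transform(docs)
+labels = KMeans(n_clusters=2, n_init=10).fit_predict(features)
+```
+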
-上述命令运行结束后,会在`../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering`路径下生成若干文件,保存了dense层各神经元所对应的专家索引,名称格式如下:
+Then, you need to run the following script to get the importance vector $v$ for the intermediate neurons in each layer:
-```sh
-layers.0.mlp.gate_proj.weight
-layers.1.mlp.gate_proj.weight
-layers.2.mlp.gate_proj.weight
-......
+```shell
+bash scripts/moefication/split/run_split_gradient_get_grads.sh
```
-
+Remember to change the following variables:
-#### 2.2 获取LLaMA的中间层特征 (get_hidden_features)
+```shell
+dataset_dir="" # path to clustered data
+pretrained_model="" # path to the LLaMA checkpoint
+tokenizer_path="" # path to the LLaMA tokenizer
+save_path="" # path to save the importance scores
-##### 2.2.1 预训练数据准备
+accumulate_level="" # should be set to `sample`
+kernel="" # should be set to `l1_norm`
+importance_type="" # should be set to `feature_change`
+```
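+
+One plausible reading of the settings above (`feature_change` importance with an `l1_norm` kernel, accumulated per `sample`) is sketched below; the shapes and the loss are stand-ins, not the `smoe` internals:
+
+```python
+# Hedged sketch: accumulate a per-neuron importance vector v.
+import torch
+
+v = torch.zeros(11008)  # one score per intermediate neuron
+for _ in range(8):  # stand-in for batches drawn from one data cluster
+    h = torch.randn(32, 11008, requires_grad=True)  # FFN intermediate features
+    loss = h.square().sum()  # stand-in for the LM loss
+    loss.backward()
+    # first-order estimate of how much the loss changes if a neuron is zeroed
+    v += (h.grad * h.detach()).abs().sum(dim=0)
+```
+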
-训练MoE门控网络需要使用LLaMA的预训练数据,下图为LLaMA预训练时各数据集的采样比例。
-
-此处在`./llama_data`文件夹下提供了上述数据集的子集,使用前先解压`zip`文件,得到以`jsonl`格式存储的数据集。
+##### Neuron-Independent
-提供的数据集中各个数据集的token数量基本相近,具体的比例分配在`./run_moefication/llama_get_hidden_features.py`中实现。
+> This part is not included in our technical report.
-##### 2.2.2 使用LLaMA进行推理
+You can also split the intermediate neurons in a neuron-independent manner by treating the expert split as a task assignment problem. To perform the split, you can run:
+
+```shell
+bash ./scripts/moefication/split/run_split_gradient.sh
+```
-运行`./run_moefication/llama_select_mlp.py`文件,并传入下述参数:
+Remember to change the following variables:
-| 参数名 | 说明 |
-| --------------------- | ------------------------------------------------------------ |
-| model_path | transformers库格式的LLaMA模型的路径 |
-| train_data_path | LLaMA预训练数据的保存路径 (**步骤2.2.1**) |
-| train_data_cache_path | LLaMA预训练数据的缓存保存路径,用于保存与读取数据缓存,减少数据处理时间 |
-| save_path | 保存门控网络参数文件的位置,会该位置下自动创建名如"xxx-Hidden-Features"的路径进行保存 |
-| templates | 专家划分所依据的网络参数模板,默认为"layers.{}.mlp.gate_proj.weight" |
-| data_use_percent | 所有数据集中数据的使用比例,用于调节训练使用的数据量
3.5M的token在LLaMA 7B上得到的中间层输出文件总大小约4.5T |
-| batch_size | 单次推理的batch_size,根据显存大小调节 |
-| save_interval | 保存参数的batch间隔,调大会影响显存占用,但可以减少保存的文件个数 |
+```shell
+expert_num="" # number of experts in each MoE layer
+expert_size="" # number of intermediate neurons in each expert
+share_neurons="False" ######### SET TO FALSE FOR NEURON-INDEPENDENT SPLIT #########
-推理代码使用了torch原生多卡形式,需使用`python -m torch.distributed.launch`运行,以下为运行命令的范例:
+model_path="" # path to the LLaMA checkpoint
+score_file_path="" # path to the score files generated above
+save_path="" # path to save the indices sets
+visualization_path="" # path to save the visualization results
-```sh
-python -m torch.distributed.launch --nproc_per_node=8 llama_get_hidden_features.py --model_path /llama_7B --train_data_path ../llama_data --train_data_cache_path ../llama_data_cache --save_path ../llama_moe_temp_files --data_use_percent 0.01 --save_interval 1 --batch_size 4
+criterion="" # criterion to judge the importance of neurons, should be set to `max`
+proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj`
```
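+
+Under the task-assignment view, every neuron goes to exactly one expert while experts stay equally sized. A greedy sketch of this idea (illustrative; it assumes one importance vector per expert derived from the clustered data):
+
+```python
+# Greedy sketch of neuron-independent assignment (illustrative only).
+import torch
+
+expert_num, hidden = 8, 11008
+expert_size = hidden // expert_num
+scores = torch.rand(expert_num, hidden)  # per-expert importance of each neuron
+
+assignment = [[] for _ in range(expert_num)]
+taken = set()
+# visit (expert, neuron) pairs from the highest score down, respecting capacity
+for idx in scores.flatten().argsort(descending=True).tolist():
+    e, n = divmod(idx, hidden)
+    if n not in taken and len(assignment[e]) < expert_size:
+        assignment[e].append(n)
+        taken.add(n)
+```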
-上述命令运行结束后,会在`../llama_moe_temp_files/llama_7B-Hidden-Features`路径下生成若干文件,保存了中间层输入输出,名称格式如下:
-```sh
--- hidden_inputs
- 0_0.pth
- 0_1.pth
- 0_2.pth
- ......
--- hidden_gate_outputs
- 0_0.pth
- 0_1.pth
- 0_2.pth
- ......
+
+##### Inner-Sharing
+
+This strategy uses the same entry script as the **Neuron-Independent** strategy above:
+
+```shell
+bash ./scripts/moefication/split/run_split_gradient.sh
```
-
+Remember to change the following variables:
-#### 2.3 训练专家选择网络,保存门控网络的参数 (select)
+```shell
+expert_num="" # number of experts in each MoE layer
+expert_size="" # number of intermediate neurons in each expert
+share_neurons="True" ######### SET TO TRUE FOR INNER-SHARING SPLIT #########
-> Moefication原文提供了***Similarity***与***MLP***两种门控机制。
->
-> 此处选择常规的MLP作为门控网络,网络输出对应各个专家的选择分数,该方案效果较好。
+model_path="" # path to the LLaMA checkpoint
+score_file_path="" # path to the score files generated above
+save_path="" # path to save the indices sets
+visualization_path="" # path to save the visualization results
+
+criterion="" # criterion to judge the importance of neurons, should be set to `max`
+proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj`
+```
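+
+With `share_neurons="True"`, each expert simply keeps its own top-scoring neurons, so the same neuron may appear in several experts. A hedged sketch (names and shapes are illustrative):
+
+```python
+# Hedged sketch of inner-sharing: expert neuron sets may overlap.
+import torch
+
+expert_num, expert_size, hidden = 8, 2048, 11008
+scores = torch.rand(expert_num, hidden)  # one importance vector per expert
+indices_sets = [torch.topk(scores[e], expert_size).indices.tolist()
+                for e in range(expert_num)]
+# unlike the neuron-independent partition, these sets are not disjoint
+```
+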
-运行`./run_moefication/llama_select_mlp.py`文件,并传入下述参数:
-| 参数名 | 说明 |
-| -------------------- | ------------------------------------------------------------ |
-| model_path | transformers库格式的LLaMA模型的路径 |
-| split_file_path | **步骤2.1**中保存的专家索引路径 |
-| hidden_features_path | **步骤2.2**中保存的LLaMA中间层特征路径 |
-| save_path | 保存门控网络参数文件的位置,会该位置下自动创建名如"xxx-xxxExpert-Select-MLP"的路径进行保存 |
-| templates | 专家划分所依据的网络参数模板,默认为"layers.{}.mlp.gate_proj.weight" |
-| select_criterion | 专家划分的依据指标,有plain、positive、l2_norm三种方式,默认使用l2_norm |
-| num_experts | 专家划分的数量 |
-| num_selects | 专家选择的数量 |
-| specify_layer | 指定对哪些层进行划分,用于并行执行划分操作,**留空则对所有层进行划分**
如LLaMA 7B有32层,则可使用4条命令,分别指定0-7、8-15、16-23、24-32进行并行训练 |
-| use_softmax | 添加后则使用Softmax激活MoE Gate输出,建议添加 |
-下面为运行命令的范例:
+##### Inter-Sharing (Residual MoE)
-```sh
-python llama_split_clustering.py --model_path /llama_7B --split_file_path ../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering --hidden_features_path ../llama_moe_temp_files/llama_7B-Hidden-Features --save_path ../llama_moe_temp_files --select_criterion l2_norm --num_experts 8 --num_selects 2 --use_softmax
+You can run the following script to perform the inter-sharing split:
+
+```shell
+bash ./scripts/moefication/split/run_split_gradient_residual.sh
```
-上述命令运行结束后,会在`../llama_moe_temp_files/llama_7B-8Expert-Select-MLP`路径下生成若干文件,保存了各个门控网络的参数,名称格式如下:
+Remember to change the following variables:
+
+```shell
+expert_num_moe="" # number of non-residual experts
+expert_num_residual="" # number of residual experts
+expert_size="" # number of intermediate neurons in each expert
+share_neurons="" # whether to share neurons among the non-residual experts
+
+model_path="" # path to the LLaMA checkpoint
+score_file_path="" # path to the score files generated above
+save_path="" # path to save the indices sets
+visualization_path="" # path to save the visualization results
-```sh
-layers.0.mlp.gate_proj.weight
-layers.1.mlp.gate_proj.weight
-layers.2.mlp.gate_proj.weight
-......
+criterion="" # criterion to judge the importance of neurons, should be set to `max`
+proj_type="" # weights to perform clustering, choices: `up_proj` `gate_proj`
```
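+
+In a residual MoE, the residual experts are always active while the remaining experts are routed per token. A self-contained forward-pass sketch (module shapes are illustrative, not the repo's classes):
+
+```python
+# Conceptual forward pass of a residual MoE layer (illustrative).
+import torch
+import torch.nn as nn
+
+hidden, expert_size, num_experts, top_k = 4096, 688, 8, 2
+
+def make_expert() -> nn.Sequential:
+    return nn.Sequential(nn.Linear(hidden, expert_size), nn.ReLU(),
+                         nn.Linear(expert_size, hidden))
+
+experts = nn.ModuleList(make_expert() for _ in range(num_experts))
+residual = make_expert()  # always-on expert
+router = nn.Linear(hidden, num_experts)
+
+x = torch.randn(2, hidden)
+scores = torch.softmax(router(x), dim=-1)
+topv, topi = scores.topk(top_k, dim=-1)
+mask = torch.zeros_like(scores).scatter(1, topi, topv)  # keep top-k gates only
+expert_outs = torch.stack([e(x) for e in experts], dim=1)  # (batch, experts, hidden)
+out = residual(x) + (mask.unsqueeze(-1) * expert_outs).sum(dim=1)
+```
+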
-
-#### 2.4 转换LLaMA模型并保存 (convert)
-修改`./run_moefication/llama_convert.py.py`文件第4行的路径为项目根路径。
+### Convert
+
+#### Convert LLaMA-MoE from Neuron-Independent Methods
+
+Run the following script:
-```python
-sys.path.append("") # 修改为项目根路径,如"/home/dongdz/workspace/llama-moefication/"
+```shell
+bash ./scripts/moefication/convert/run_convert.sh
```
-运行`./run_moefication/llama_convert.py`文件,并传入下述参数:
-| 参数名 | 说明 |
-| ---------------- | ------------------------------------------------------------ |
-| model_path | transformers库格式的LLaMA模型的路径 |
-| split_file_path | **步骤2.1**中保存的专家索引路径 |
-| select_file_path | **步骤2.3**中保存的门控网络参数路径 |
-| save_path | 保存转换后的transformers库格式的LLaMA-MoE模型的路径 |
-| templates | 专家划分所依据的网络参数模板,默认为"layers.{}.mlp.gate_proj.weight" |
-| num_experts | 专家划分的数量 |
-| num_selects | 专家选择的数量 |
-| convert_type | 转换的模型类别,可选项如下:
LlamaMoEModel、LlamaMoEForCausalLM、LlamaMoEForSequenceClassification |
-下面为运行命令的范例:
+#### Convert LLaMA-MoE from Inner-Sharing Methods
-```sh
-python llama_convert.py --model_path /llama_7B --split_file_path ../llama_moe_temp_files/llama_7B-8Expert-Split-Clustering --select_file_path ../llama_moe_temp_files/llama_7B-8Expert-Select-MLP --save_path /llama_7B-MoE --num_experts 8 --num_selects 2 --convert_type LlamaMoEForCausalLM
+Run the following script:
+
+```shell
+bash ./scripts/moefication/convert/run_convert_gradient.sh
```
-上述命令运行结束后,会在`/7B-MoE`路径下保存LLaMA-MoE模型。
-
-## 代码相关——MoE实现逻辑
+#### Convert LLaMA-MoE from Inter-Sharing Methods (Residual MoE)
+
+Run the following script:
+
+```shell
+bash ./scripts/moefication/convert/run_convert_gradient_residual.sh
+```
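+
+After conversion, the saved checkpoint should load like a regular `transformers` model for a quick smoke test. A sketch (the import path for `LlamaMoEForCausalLM` is an assumption; check the repo's model package for the real location):
+
+```python
+import torch
+from transformers import AutoTokenizer
+
+from smoe.models.llama_moe import LlamaMoEForCausalLM  # hypothetical import path
+
+model = LlamaMoEForCausalLM.from_pretrained("/path/to/converted/llama-moe")
+tokenizer = AutoTokenizer.from_pretrained("/path/to/llama")
+inputs = tokenizer("MoEfication splits a dense FFN into", return_tensors="pt")
+with torch.no_grad():
+    output = model.generate(**inputs, max_new_tokens=16)
+print(tokenizer.decode(output[0]))
+```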
+
+## File Structure
+
+```
+smoe
+├── scripts
+│   └── moefication
+│       ├── convert
+│       ├── get_hidden_features (deprecated)
+│       ├── prune (deprecated)
+│       ├── select (deprecated)
+│       └── split
+└── smoe
+    └── entrypoint
+        └── moefication
+```
+
-可以自己看`./llama_moe/moe_utils/`下的代码,其中`moe_layers.py`是封装好的MoE层,用于直接调用。
diff --git a/docs/moefication/readme-image.png b/docs/moefication/readme-image.png
deleted file mode 100644
index 8917c2d..0000000
Binary files a/docs/moefication/readme-image.png and /dev/null differ
diff --git a/scripts/moefication/convert/run_convert_gradient_residual.sh b/scripts/moefication/convert/run_convert_gradient_residual.sh
index a516d41..0a21cc1 100644
--- a/scripts/moefication/convert/run_convert_gradient_residual.sh
+++ b/scripts/moefication/convert/run_convert_gradient_residual.sh
@@ -13,7 +13,7 @@ expert_size=1376
# 688 1376 2752 5504 11008
# 864 1728 3456 6912 13824
-score_scale_factor_residual=4.0 # 4.0 8.0 12.0 16.0
+score_scale_factor_residual=1.0 # 4.0 8.0 12.0 16.0
score_scale_factor=4.0 # 4.0 8.0 12.0 16.0
convert_type=LlamaMoEResidualForCausalLM # LlamaMoEResidualModel LlamaMoEResidualForCausalLM LlamaMoEResidualForSequenceClassification
diff --git a/scripts/moefication/split/run_split_graph.py b/scripts/moefication/split/run_split_graph.py
deleted file mode 100644
index 96705bb..0000000
--- a/scripts/moefication/split/run_split_graph.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import subprocess
-
-# Define the bash commands
-
-
-bash_commands = """
-# llama_7B llama_13B llama_30B llama_base
-# llama2_7B llama2_13B llama2_30B llama2_base
-llama_size=llama_13B
-
-num_experts=16 # 8 16
-metric=l1_norm # l1_norm l2_norm plain
-template=layers.{}.mlp.up_proj.weight # gate_proj up_proj
-threshold=1
-
-data_path=/mnt/petrelfs/share_data/quxiaoye
-model_path=${data_path}/models/${llama_size}
-save_path=${data_path}/moefication_results/split/${llama_size}-${num_experts}Expert-Split-Graph-${metric}/
-hidden_features_path=${data_path}/moefication_results/features/${llama_size}-Hidden-Features
-
-gpus=0
-cpus=16
-
-# STEP1
-
-for specify_layer in {0..39}; do
- OMP_NUM_THREADS=2 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \
- python -m smoe.entrypoint.moefication.llama_split_graph \
- --model_path ${model_path} \
- --save_path ${save_path} \
- --specify_layer ${specify_layer} \
- --template ${template} \
- --num_experts ${num_experts} \
- --threshold ${threshold} \
- --metric ${metric} \
- --hidden_features_path ${hidden_features_path} &
- sleep 0.7
-done
-wait
-
-# STEP2
-
-gpmetis_run=/mnt/petrelfs/share_data/quxiaoye/metis_for_graph_split/bin/gpmetis
-template1=layers.
-template2=.mlp.up_proj.weight
-
-for layer in {0..39}; do
- OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \
- ${gpmetis_run} ${save_path}/${template1}${layer}${template2} ${num_experts} &
- sleep 0.7
-done
-wait
-# STEP3
-
-template3=.part.${num_experts}
-
-for layer in {0..39}; do
- OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \
- python -m smoe.entrypoint.moefication.llama_split_graph_trans_gp \
- --gpmetised_file_path ${save_path}/${template1}${layer}${template2}${template3} &
- sleep 0.7
-done
-wait
-chmod -R 755 ${save_path} >/dev/null 2>&1
-"""
-
-# Execute the bash commands using Python's subprocess module
-subprocess.run(bash_commands, shell=True, executable="/bin/bash")
diff --git a/scripts/moefication/split/run_split_graph.sh b/scripts/moefication/split/run_split_graph.sh
index 589c762..b892014 100644
--- a/scripts/moefication/split/run_split_graph.sh
+++ b/scripts/moefication/split/run_split_graph.sh
@@ -4,9 +4,9 @@
# llama2_7B llama2_13B llama2_30B llama2_base
llama_size=llama_13B
-num_experts=8 # 8 16
-metric=l2_norm # l1_norm l2_norm plain
-template=layers.{}.mlp.up_proj.weight # gate_proj up_proj
+num_experts=16 # 8 16
+metric=l1_norm # l1_norm l2_norm plain
+proj_type=up_proj # gate_proj up_proj
threshold=1
data_path=/mnt/petrelfs/share_data/quxiaoye
@@ -25,7 +25,7 @@ for specify_layer in {0..39}; do
--model_path ${model_path} \
--save_path ${save_path} \
--specify_layer ${specify_layer} \
- --template ${template} \
+ --template layers.{}.mlp.${proj_type}.weight \
--num_experts ${num_experts} \
--threshold ${threshold} \
--metric ${metric} \
@@ -38,7 +38,7 @@ wait
gpmetis_run=/mnt/petrelfs/share_data/quxiaoye/metis_for_graph_split/bin/gpmetis
template1=layers.
-template2=.mlp.up_proj.weight
+template2=.mlp.${proj_type}.weight
for layer in {0..39}; do
OMP_NUM_THREADS=8 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=auto \
@@ -46,6 +46,7 @@ for layer in {0..39}; do
sleep 0.7
done
wait
+
# STEP3
template3=.part.${num_experts}
@@ -57,4 +58,5 @@ for layer in {0..39}; do
sleep 0.7
done
wait
+
chmod -R 755 ${save_path} >/dev/null 2>&1
diff --git a/scripts/moefication/split/run_split_random.sh b/scripts/moefication/split/run_split_random.sh
index da2485b..aecdc8c 100644
--- a/scripts/moefication/split/run_split_random.sh
+++ b/scripts/moefication/split/run_split_random.sh
@@ -5,8 +5,7 @@
# open_llama_7b
llama_size="open_llama_7b"
-num_experts=8 # 8 16
-proj_type=gate_proj # gate_proj up_proj
+num_experts=8 # 8 16
data_path=/mnt/petrelfs/share_data/quxiaoye
model_path=${data_path}/models/${llama_size}
@@ -18,7 +17,7 @@ OMP_NUM_THREADS=2 srun --partition=MoE --job-name=split --mpi=pmi2 --gres=gpu:${
python -m smoe.entrypoint.moefication.llama_split_random \
--model_path ${model_path} \
--save_path ${save_path} \
- --template layers.{}.mlp.${proj_type}.weight \
+ --template layers.{}.mlp.up_proj.weight \
--num_experts ${num_experts}
chmod -R 755 ${save_path} >/dev/null 2>&1
diff --git a/smoe/entrypoint/moefication/llama_split_random.py b/smoe/entrypoint/moefication/llama_split_random.py
index b9c4146..eb26dad 100644
--- a/smoe/entrypoint/moefication/llama_split_random.py
+++ b/smoe/entrypoint/moefication/llama_split_random.py
@@ -11,7 +11,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default="/home/data/models/llama-transformers/7B")
parser.add_argument('--save_path', type=str, default="/home/dongdz/workspace/moefication/llama_moe_temp_files/")
- parser.add_argument('--template', type=str, default='layers.{}.mlp.gate_proj.weight')
+ parser.add_argument('--template', type=str, default='layers.{}.mlp.up_proj.weight')
parser.add_argument('--num_experts', type=int, default=8, help='number of experts')
args = parser.parse_args()