add the gate function of softmoe
JCruan519 committed Sep 19, 2023
1 parent f65a09e commit c71a7f9
Showing 14 changed files with 755 additions and 50,011 deletions.
1 change: 0 additions & 1 deletion models/124M/encoder.json

This file was deleted.

50,001 changes: 0 additions & 50,001 deletions models/124M/vocab.bpe

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ coverage==7.2.7
datasets==2.14.1
debugpy==1.6.7
deepspeed==0.10.0
+einops==0.6.1
flake8==6.0.0
huggingface-hub==0.16.4
isort==5.12.0
16 changes: 16 additions & 0 deletions scripts/examples/create_soft_llama_moe.sh
@@ -0,0 +1,16 @@
#!/usr/bin/bash

# llama_7B llama_13B llama_30B llama_base
# llama2_7B llama2_13B llama2_30B llama2_base
base_model=llama_7B

model_type=LlamaMoEForCausalLM # LlamaMoEModel LlamaMoEForCausalLM LlamaMoEForSequenceClassification

tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/${base_model}/

gpus=1
cpus=16
OMP_NUM_THREADS=8 srun --partition=MoE --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --job-name=example --kill-on-bad-exit=1 \
python -m smoe.entrypoint.examples.create_soft_llama_moe \
--tokenizer_path ${tokenizer_path} \
--model_type ${model_type}
2 changes: 1 addition & 1 deletion scripts/examples/create_switch_llama_moe.sh
@@ -8,7 +8,7 @@ model_type=LlamaMoEForCausalLM # LlamaMoEModel LlamaMoEForCausalLM LlamaMoEForSequenceClassification

tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/${base_model}/

-gpus=0
+gpus=1
cpus=16
OMP_NUM_THREADS=8 srun --partition=MoE --job-name=test --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --job-name=example --kill-on-bad-exit=1 \
python -m smoe.entrypoint.examples.create_switch_llama_moe \
58 changes: 58 additions & 0 deletions scripts/moefication/run_select.py
@@ -0,0 +1,58 @@
import subprocess

# Define the bash commands

bash_commands = """
#!/usr/bin/bash
# llama_7B llama_13B llama_30B llama_base
# llama2_7B llama2_13B llama2_30B llama2_base
llama_size="llama2_7B"
num_experts=8 # 8 16
num_selects=2 # 2 4
split_type=Graph-l2_norm # Clustering-l2 Clustering-cos Random Graph-l1_norm Graph-l2_norm
select_type=l2_norm # plain positive l2_norm l1_norm
proj_type=gate_proj # gate_proj up_proj
train_percent=0.95
batch_size=1024
epochs=200
lr=0.01
data_path=/mnt/petrelfs/share_data/quxiaoye
model_path=${data_path}/models/${llama_size}
split_file_path=${data_path}/moefication_results/split/${llama_size}-${num_experts}Expert-Split-${split_type}
hidden_features_path=${data_path}/moefication_results/features/${llama_size}-Hidden-Features
save_path=${data_path}/moefication_results/select/${split_type}
save_visualization_path=/mnt/petrelfs/dongdaize.d/workspace/train-moe/visualization-scheduler-train2/${split_type}-${select_type}/${llama_size}-${num_experts}Select${num_selects}-${proj_type}
gpus=1
cpus=16
for specify_layer in "0 1 2 3" "4 5 6 7" "8 9 10 11" "12 13 14 15" "16 17 18 19" "20 21 22 23" "24 25 26 27" "28 29 30 31"; do # launch one job per layer group in parallel
OMP_NUM_THREADS=8 srun --partition=MoE --job-name=select --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 \
python -m smoe.entrypoint.moefication.llama_select_mlp \
--model_path ${model_path} \
--split_file_path ${split_file_path} \
--hidden_features_path ${hidden_features_path} \
--save_path ${save_path} \
--save_visualization_path ${save_visualization_path} \
--specify_layer ${specify_layer} \
--template layers.{}.mlp.${proj_type}.weight \
--num_experts ${num_experts} \
--num_selects ${num_selects} \
--select_criterion ${select_type} \
--use_softmax \
--train_percent ${train_percent} \
--batch_size ${batch_size} \
--epochs ${epochs} \
--lr ${lr} & # run the next command in parallel
sleep 0.5 # wait 0.5 s before the next submission
done
wait # block until all backgrounded srun jobs have finished
"""

# Execute the bash commands using Python's subprocess module
subprocess.run(bash_commands, shell=True, executable="/bin/bash")
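
As a side note, the same staggered parallel launch could be written directly with subprocess.Popen instead of an embedded bash script. This is only a sketch under the assumption that the srun flags and entrypoint arguments stay exactly as in the script above; the trailing flags are elided here:

import subprocess
import time

# one group of four layers per parallel srun job, as in the bash loop above
layer_groups = [
    "0 1 2 3", "4 5 6 7", "8 9 10 11", "12 13 14 15",
    "16 17 18 19", "20 21 22 23", "24 25 26 27", "28 29 30 31",
]

procs = []
for layers in layer_groups:
    cmd = (
        "OMP_NUM_THREADS=8 srun --partition=MoE --job-name=select --mpi=pmi2 "
        "--gres=gpu:1 -n1 --ntasks-per-node=1 -c 16 --kill-on-bad-exit=1 "
        "python -m smoe.entrypoint.moefication.llama_select_mlp "
        f"--specify_layer {layers}"  # remaining flags as in the bash script above
    )
    procs.append(subprocess.Popen(cmd, shell=True, executable="/bin/bash"))
    time.sleep(0.5)  # stagger submissions, matching the original sleep

for p in procs:
    p.wait()  # block until every selection job has finished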
106 changes: 106 additions & 0 deletions smoe/entrypoint/examples/create_soft_llama_moe.py
@@ -0,0 +1,106 @@
"""
Create a LLaMA MoE model with SoftMoEGate.
"""

import argparse

import numpy as np
import torch.cuda
from transformers import LlamaTokenizer

from smoe.models.llama_moefication.configuration_llama_moe import LlamaMoEConfig
from smoe.models.llama_moefication.modeling_llama_moe import (
LlamaMoEForCausalLM,
LlamaMoEForSequenceClassification,
LlamaMoEModel,
)


def main(args):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    """set up configs"""
    # model size parameters
    intermediate_size = 11008
    num_hidden_layers = 32

    # MoE expert configuration
    num_experts = 4
    num_selects = 1  # number of selected experts; fixed to 1 for this gate
    size_experts = []  # number of neurons in each expert; if None, all experts share the same size

    # MoE gating network configuration
    gate_type = "SoftMoEGate"
    slots_per_expert = 1

    # MoE calculator configuration
    calculator_type = "SoftMoECalculator"

    # randomly generate the size of each expert and append it to size_experts
    for i in range(num_hidden_layers):
        this_size = np.random.randint(
            1, high=intermediate_size // num_experts + 1, size=num_experts
        )
        diff = intermediate_size - np.sum(this_size)  # adjust the last entry so the sizes sum to the target
        this_size[-1] += diff
        size_experts.append(this_size)
    print("size_experts: ", size_experts)

    """create model"""
    print("Creating model...")
    config_llama_moe = LlamaMoEConfig(
        intermediate_size=intermediate_size,
        num_hidden_layers=num_hidden_layers,
        num_experts=num_experts,
        num_selects=num_selects,
        size_experts=size_experts,
        gate_type=gate_type,
        slots_per_expert=slots_per_expert,
        calculator_type=calculator_type,
    )

    if args.model_type == "LlamaMoEModel":
        model_llama_moe = LlamaMoEModel(config_llama_moe)
    elif args.model_type == "LlamaMoEForCausalLM":
        model_llama_moe = LlamaMoEForCausalLM(config_llama_moe)
    elif args.model_type == "LlamaMoEForSequenceClassification":
        model_llama_moe = LlamaMoEForSequenceClassification(config_llama_moe)
    else:
        raise ValueError(f"Unknown model_type: {args.model_type}")

    """prepare data"""
    sentence_list = [
        "hi hi hi hi hi, hi hi hi hi hi, hi hi hi hi hi",
        "How are you? I'm fine, and you?",
        "<s> <unk> <unk> <unk> <unk> <unk> </s>",
        "I am stupid. Are you sure?",
        "The past is never dead. It is not even past.",
    ]

    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokens = tokenizer(sentence_list, padding=True, return_tensors="pt")
    print(tokens)

    """forward test"""
    print("Forwarding inputs...")
    model_llama_moe.to(device).half()
    tokens.to(device)
    result = model_llama_moe(**tokens)  # noqa: F841
    # print(result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer_path", type=str)
    parser.add_argument(
        "--model_type",
        type=str,
        choices=(
            "LlamaMoEModel",
            "LlamaMoEForCausalLM",
            "LlamaMoEForSequenceClassification",
        ),
    )
    args = parser.parse_args()
    main(args)
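
The SoftMoEGate and SoftMoECalculator implementations themselves are not shown in this diff. For orientation only, here is a minimal sketch of soft-MoE style dispatch/combine gating in the spirit of Puigcerver et al. (2023), which the new einops dependency and the slots_per_expert parameter suggest; every name below (SoftMoEGateSketch, phi, and so on) is illustrative and not this repository's API:

import torch
import torch.nn as nn
from einops import rearrange


class SoftMoEGateSketch(nn.Module):
    """Illustrative soft-MoE dispatch/combine weights; NOT the repo's implementation."""

    def __init__(self, hidden_size, num_experts, slots_per_expert=1):
        super().__init__()
        self.num_experts = num_experts
        # one learned query per slot: (num_experts * slots_per_expert, hidden_size)
        self.phi = nn.Parameter(torch.randn(num_experts * slots_per_expert, hidden_size))

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        logits = torch.einsum("bnd,sd->bns", x, self.phi)  # token-slot affinities
        dispatch = logits.softmax(dim=1)  # normalize over tokens: each slot is a convex mix of tokens
        combine = logits.softmax(dim=2)   # normalize over slots: each token is a convex mix of slot outputs
        slots = torch.einsum("bns,bnd->bsd", dispatch, x)  # fill the expert slots
        # group the slots by expert for the expert forward pass
        per_expert = rearrange(slots, "b (e p) d -> b e p d", e=self.num_experts)
        return per_expert, combine

Each expert would then process its own slots, the expert outputs would be flattened back to (batch, slots, hidden), and folded into token space via torch.einsum("bns,bsd->bnd", combine, expert_out). Because every token contributes to every slot, soft MoE needs no discrete top-k routing, which may be why num_selects is set to 1 in the example above.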