From a429e8c82f5743a5d64d6dc320dfc078aad3faf9 Mon Sep 17 00:00:00 2001 From: zhutong Date: Wed, 29 Nov 2023 10:43:59 +0800 Subject: [PATCH 01/12] update cpt scripts, add `msg_prefix` in notification, add `gate_balance_loss_weight` in model arguments --- .../baseline_112gpus_linear_gate.sh | 167 +++++++++++++++++ .../baseline_112gpus_sheared_llama_portion.sh | 10 +- ...e_112gpus_sheared_llama_portion_fluency.sh | 169 ++++++++++++++++++ ...ared_llama_portion_gate_balance_loss0.1.sh | 168 +++++++++++++++++ ...ine_112gpus_sheared_llama_portion_no_ad.sh | 169 ++++++++++++++++++ smoe/entrypoint/analysis/gate_load_vis.py | 77 ++++++-- smoe/entrypoint/cpt/cpt_fpt.py | 3 +- smoe/utils/config.py | 6 + smoe/utils/notification.py | 29 ++- tests/data/test_streaming.py | 79 ++++++-- 10 files changed, 834 insertions(+), 43 deletions(-) create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh new file mode 100644 index 0000000..03c65e0 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh @@ -0,0 +1,167 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, one linear layer gate" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --gate_network_type "linear" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh index a74b1d1..fc60f05 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh @@ -79,9 +79,11 @@ source ~/anaconda3/bin/activate smoe echo "global batch size: $global_bs" tokens_per_batch=$(echo "$global_bs * $block_size" | bc) echo "#tokens/batch: $tokens_per_batch" - warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" - eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" data_cache=resources/cache @@ -92,7 +94,7 @@ source ~/anaconda3/bin/activate smoe scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh git diff > $output_dir/diff.patch env > $output_dir/env - echo $comment > $output_dir/comment.txt + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt echo "$SLURM_JOB_ID" > $base_dir/latest.jobid ln -snf $output_dir $base_dir/latest.dir ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log @@ -136,7 +138,7 @@ source ~/anaconda3/bin/activate smoe --learning_rate ${lr} \ --weight_decay 0.1 \ --max_grad_norm 1.0 \ - --warmup_steps 100 \ + --warmup_steps ${warmup_steps} \ --max_steps ${max_steps} \ --max_train_samples ${max_train_samples} \ --save_strategy steps \ diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh new file mode 100644 index 0000000..ae8a38c --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh @@ -0,0 +1,169 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh new file mode 100644 index 0000000..3de43f8 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh @@ -0,0 +1,168 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --gate_balance_loss_weight 0.1 \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh new file mode 100644 index 0000000..4d93400 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh @@ -0,0 +1,169 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --resume_from_checkpoint "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2323339/checkpoint-340" \ + --prob_map "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/smoe/entrypoint/analysis/gate_load_vis.py b/smoe/entrypoint/analysis/gate_load_vis.py index 396d7b0..d94bd66 100644 --- a/smoe/entrypoint/analysis/gate_load_vis.py +++ b/smoe/entrypoint/analysis/gate_load_vis.py @@ -1,6 +1,5 @@ from pathlib import Path -import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import torch @@ -15,10 +14,13 @@ @torch.no_grad() -def main(): +def main( + model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/outputs/cpt-llama2_random_scale4_112gpus-2220221/checkpoint-13600/", + result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/", +): bsz = 8 # model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-16Select4-688Neurons-Share" - model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage1-2090437/checkpoint-4000" + # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage1-2090437/checkpoint-4000" # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage2-2105807/checkpoint-4000" eval_path_map = { "en_wikipedia": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_wikipedia.jsonl", @@ -33,7 +35,7 @@ def main(): "hellaswag": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/hellaswag.jsonl", "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", } - result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" + # result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" result_dir = Path(result_dir) result_dir.mkdir(exist_ok=True, parents=True) @@ -57,7 +59,13 @@ def main(): eval_dataset, batch_size=bsz, collate_fn=fault_tolerance_data_collator ) loader = accel.prepare_data_loader(loader) - for batch in tqdm(loader, desc=name): + if name == "en_book": + num_batch = 20 + else: + num_batch = 9999999999999999 + for batch_idx, batch in enumerate(tqdm(loader, desc=name)): + if batch_idx >= num_batch: + break outs = model(**batch, output_attentions=False, use_cache=False) # gate_load: (tensor([1.0, 2.3, ... num_experts]), tensor([3.0, 4.5, ... num_experts]), ... num_layers) gate_load = outs.gate_load @@ -102,20 +110,26 @@ def heatmap( fig.savefig(save_path, dpi=320, bbox_inches="tight") -def calc_sim(): - gate_load_folder = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" +def calc_sim( + # gate_load_folder = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" + gate_load_folder="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/", + layer_idx=0, + plot=True, +): # title = "SlimPajama" # sim_pairs = [["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], ["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"]] - # title = "Dev vs. SlimPajama" - # sim_pairs = [["arc_challenge", "gsm8k", "hellaswag", "mmlu"], ["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"]] - title = "Dev vs. Dev" + title = "Dev vs. SlimPajama" sim_pairs = [ ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], ] + # title = "Dev vs. Dev" + # sim_pairs = [ + # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + # ] # title = "test" # sim_pairs = [["wiki", "github"], ["wiki", "github"]] - layer_idx = 0 folder = Path(gate_load_folder) name2arr = {} @@ -135,14 +149,41 @@ def calc_sim(): t1_load = name2arr[type1] for t2_idx, type2 in enumerate(sim_pairs[1]): t2_load = name2arr[type2] - # _sim = np.dot(t1_load, t2_load) / (np.linalg.norm(t1_load) * np.linalg.norm(t2_load)) - _sim = 1.0 - np.linalg.norm(t1_load - t2_load, 2) + _sim = np.dot(t1_load, t2_load) / ( + np.linalg.norm(t1_load) * np.linalg.norm(t2_load) + ) + # _sim = 1.0 - np.linalg.norm(t1_load - t2_load, 2) sim_arr[t1_idx][t2_idx] = _sim - heatmap( - sim_arr, sim_pairs[1], sim_pairs[0], str(folder / f"sim_{title}.png"), title - ) + if plot: + heatmap( + sim_arr, + sim_pairs[1], + sim_pairs[0], + str(folder / f"layer{layer_idx}" / f"cos_sim_{title}.png"), + title, + ) + + return sim_arr if __name__ == "__main__": # main() - calc_sim() + + sim_arr_list = [] + for layer_idx in range(32): + sim_arr = calc_sim(layer_idx=layer_idx) + sim_arr_list.append(sim_arr) + sim_arr = np.stack(sim_arr_list, axis=0) + sim_arr = sim_arr.mean(axis=0) + title = "Dev vs. SlimPajama" + sim_pairs = [ + ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], + ] + heatmap( + sim_arr, + sim_pairs[1], + sim_pairs[0], + "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/cos_sim_avg.png", + title, + ) diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index 35a98e8..8dcfe80 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -65,7 +65,7 @@ logger = logging.getLogger(__name__) -@wechat_sender() +@wechat_sender(msg_prefix="CPT Training") def main(): model_args, data_args, training_args = parse_args( ModelArguments, DataArguments, EnhancedTrainingArguments @@ -143,6 +143,7 @@ def main(): "num_selects": model_args.num_selects, "gate_network": model_args.gate_network_type, "score_scale_factor": model_args.moe_calculator_score_scale_factor, + "gate_balance_loss_weight": model_args.gate_balance_loss_weight, } ConfigClass = AutoConfig if model_args.config_name == "llama_moe" or model_args.model_type == "llama_moe": diff --git a/smoe/utils/config.py b/smoe/utils/config.py index abebac6..7a0fabb 100644 --- a/smoe/utils/config.py +++ b/smoe/utils/config.py @@ -152,6 +152,12 @@ class ModelArguments: num_selects: int = field( default=4, metadata={"help": "The number of experts to be selected"} ) + gate_balance_loss_weight: float = field( + default=1e-2, + metadata={ + "help": "The weight of the balance loss for the gate, should be a float" + }, + ) def __post_init__(self): if self.config_overrides is not None and ( diff --git a/smoe/utils/notification.py b/smoe/utils/notification.py index 5b5a6af..d0cca39 100644 --- a/smoe/utils/notification.py +++ b/smoe/utils/notification.py @@ -23,10 +23,37 @@ def get_slurm_job_name(): return f"{job_name}-{job_id}" +def send_to_wechat( + msg: str, + webhook_url: str = None, + user_mentions: list[str] = None, + user_mentions_mobile: list[str] = None, +): + if not webhook_url: + webhook_url = os.environ.get("WECHAT_ROBOT_WEBHOOK") + if not user_mentions: + env_user_mentions = os.environ.get("WECHAT_ROBOT_MENTIONS", "") + user_mentions = env_user_mentions.split(",") + if not user_mentions_mobile: + env_user_mentions_mobile = os.environ.get("WECHAT_ROBOT_MENTIONS_MOBILE", "") + user_mentions_mobile = env_user_mentions_mobile.split(",") + + msg_template = { + "msgtype": "text", + "text": { + "content": msg, + "mentioned_list": user_mentions, + "mentioned_mobile_list": user_mentions_mobile, + }, + } + requests.post(webhook_url, json=msg_template) + + def wechat_sender( webhook_url: str = None, user_mentions: list[str] = [], user_mentions_mobile: list[str] = [], + msg_prefix: str = "", ): """ WeChat Work sender wrapper: execute func, send a WeChat Work notification with the end status @@ -60,7 +87,7 @@ def wechat_sender( msg_template = { "msgtype": "text", "text": { - "content": "", + "content": f"{msg_prefix}", "mentioned_list": user_mentions, "mentioned_mobile_list": user_mentions_mobile, }, diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 2e71d81..8bc2da2 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -138,24 +138,64 @@ def test_weighted_streaming_loader(): print(type(loader)) print(loader.sampler, type(loader.sampler)) - # for batch_idx, batch in enumerate(loader): - # if batch_idx == 0: - # print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") - # if num_test_case <= 0: - # break - # assert len(batch["input_ids"]) == bsz - # # print( - # # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - # # ) - # # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size - # print(loader.dataset.prob_map) - # num_test_case -= 1 - # lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) - # # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) - # print(loader.dataset.prob_map) - # # print( - # # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - # # ) + for batch_idx, batch in enumerate(loader): + if batch_idx == 0: + print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") + if num_test_case <= 0: + break + assert len(batch["input_ids"]) == bsz + # print( + # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # ) + # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size + print(loader.dataset.prob_map) + num_test_case -= 1 + lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + print(loader.dataset.prob_map) + # print( + # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # ) + + +def test_linked_dataset(): + from accelerate import Accelerator + + from smoe.data.dynamic_selection import AVERAGE_SLIMPAJAMA_DATA_PORTION + + ac = Accelerator() + + # folder_path = "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed" + folder_path = "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg" + + num_test_case = 20000 + block_size = 2048 + bsz = 1 + + lm_datasets = SubDirWeightedPackedJsonlDataset( + folder_path, + prob_map=AVERAGE_SLIMPAJAMA_DATA_PORTION, + seed=1227, + block_size=block_size, + ) + loader = DataLoader( + lm_datasets, + batch_size=bsz, + num_workers=0, + collate_fn=fault_tolerance_data_collator, + pin_memory=False, + ) + loader = ac.prepare_data_loader(loader) + print(type(loader)) + print(loader.sampler, type(loader.sampler)) + + for batch_idx, batch in enumerate(loader): + if batch_idx == 0: + print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") + if num_test_case <= 0: + break + assert len(batch["input_ids"]) == bsz + num_test_case -= 1 def test_skip_tokens(): @@ -166,4 +206,5 @@ def test_skip_tokens(): # test_jsonl_dataset() # test_subdir_weighted_pack_with_type() # test_weighted_streaming() - test_weighted_streaming_loader() + # test_weighted_streaming_loader() + test_linked_dataset() From 422d58e7d2570a4a0c6302165512675644ba79de Mon Sep 17 00:00:00 2001 From: zhutong Date: Wed, 29 Nov 2023 10:45:37 +0800 Subject: [PATCH 02/12] update reqs --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 125c8fa..98417b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,4 @@ Pillow==9.4.0 numpy==1.25.0 opencv-python==4.8.1.78 pynvml==11.5.0 +PyYaml==6.0.1 From f0ef584284b8a8333eae5929008980302756a552 Mon Sep 17 00:00:00 2001 From: zhutong Date: Fri, 1 Dec 2023 00:28:43 +0800 Subject: [PATCH 03/12] update wechat notification to support msg prefix, update vis --- .vscode/launch.json | 2 +- smoe/entrypoint/analysis/gate_load_vis.py | 142 +++++- .../analysis/hidden_before_gate_vis.py | 436 ++++++++++++++++++ smoe/utils/notification.py | 14 +- 4 files changed, 573 insertions(+), 21 deletions(-) create mode 100644 smoe/entrypoint/analysis/hidden_before_gate_vis.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 38c3d90..b2bde74 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "type": "python", "request": "attach", "connect": { - "host": "SH-IDCA1404-10-140-54-123", + "host": "SH-IDCA1404-10-140-54-12", "port": 5678 }, "pathMappings": [ diff --git a/smoe/entrypoint/analysis/gate_load_vis.py b/smoe/entrypoint/analysis/gate_load_vis.py index d94bd66..692ea08 100644 --- a/smoe/entrypoint/analysis/gate_load_vis.py +++ b/smoe/entrypoint/analysis/gate_load_vis.py @@ -18,7 +18,7 @@ def main( model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus/outputs/cpt-llama2_random_scale4_112gpus-2220221/checkpoint-13600/", result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/", ): - bsz = 8 + bsz = 4 # model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-16Select4-688Neurons-Share" # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage1-2090437/checkpoint-4000" # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage2-2105807/checkpoint-4000" @@ -31,9 +31,9 @@ def main( "en_book": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_book.jsonl", "en_arxiv": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_arxiv.jsonl", "arc_challenge": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/arc_challenge.jsonl", - "gsm8k": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/gsm8k.jsonl", + "gsm8k": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/gsm8k.jsonl", # 37998 tokens "hellaswag": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/hellaswag.jsonl", - "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", + "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", # 23720 tokens } # result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" @@ -63,6 +63,7 @@ def main( num_batch = 20 else: num_batch = 9999999999999999 + num_batch = 1 for batch_idx, batch in enumerate(tqdm(loader, desc=name)): if batch_idx >= num_batch: break @@ -99,7 +100,15 @@ def heatmap( for i in range(shape[0]): for j in range(shape[1]): - ax.text(j, i, f"{arr[i, j]:.1%}", ha="center", va="center", color="black") + text = ax.text( + j, + i, + f"{arr[i, j]:.1%}", + ha="center", + va="center", + color="black", + fontsize=6, + ) ax.set_xticks(range(len(xlabels))) ax.set_yticks(range(len(ylabels))) ax.set_xticklabels(xlabels, rotation=45, ha="right") @@ -118,11 +127,11 @@ def calc_sim( ): # title = "SlimPajama" # sim_pairs = [["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], ["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"]] - title = "Dev vs. SlimPajama" - sim_pairs = [ - ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], - ] + # title = "Dev vs. SlimPajama" + # sim_pairs = [ + # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + # ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], + # ] # title = "Dev vs. Dev" # sim_pairs = [ # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], @@ -130,6 +139,35 @@ def calc_sim( # ] # title = "test" # sim_pairs = [["wiki", "github"], ["wiki", "github"]] + title = f"Routing Similarity Layer {layer_idx}" + sim_pairs = [ + [ + "arc_challenge", + "gsm8k", + "hellaswag", + "mmlu", + "en_wikipedia", + "github", + "en_stack", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + ], + [ + "arc_challenge", + "gsm8k", + "hellaswag", + "mmlu", + "en_wikipedia", + "github", + "en_stack", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + ], + ] folder = Path(gate_load_folder) name2arr = {} @@ -166,24 +204,96 @@ def calc_sim( return sim_arr -if __name__ == "__main__": - # main() +def gate_load_vis(): + model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2326233/checkpoint-5440" + result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_80B_gate_load/" + main( + # w/ fluency filtering, 85b + model_dir=model_dir, + result_dir=result_dir, + ) sim_arr_list = [] for layer_idx in range(32): - sim_arr = calc_sim(layer_idx=layer_idx) + sim_arr = calc_sim( + gate_load_folder=result_dir, + layer_idx=layer_idx, + ) sim_arr_list.append(sim_arr) sim_arr = np.stack(sim_arr_list, axis=0) sim_arr = sim_arr.mean(axis=0) - title = "Dev vs. SlimPajama" + # title = "Dev vs. SlimPajama" + # sim_pairs = [ + # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + # ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], + # ] + title = "Routing Similarity" sim_pairs = [ - ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], + [ + "arc_challenge", + "gsm8k", + "hellaswag", + "mmlu", + "en_wikipedia", + "github", + "en_stack", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + ], + [ + "arc_challenge", + "gsm8k", + "hellaswag", + "mmlu", + "en_wikipedia", + "github", + "en_stack", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + ], ] heatmap( sim_arr, sim_pairs[1], sim_pairs[0], - "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/cos_sim_avg.png", + f"{result_dir}/cos_sim_avg.png", title, ) + + +def gate_load_vis_from_cache(name, cache_filepath, result_dir, minmax: bool = False): + gate_load_sum = np.load(cache_filepath) + if minmax: + gate_load_sum = (gate_load_sum - gate_load_sum.min()) / ( + gate_load_sum.max() - gate_load_sum.min() + ) + for layer_idx in range(gate_load_sum.shape[0]): + visualize_expert_load_heatmap( + gate_load_sum[layer_idx], + layer_idx, + name, + shape=(4, 4), + save_dir=str(result_dir), + save_fig=True, + ) + + +if __name__ == "__main__": + main( + model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2326233/checkpoint-6120", + result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_90B_gate_load/", + ) + + # gate_load_vis() + + # for name in ["gsm8k", "mmlu"]: + # gate_load_vis_from_cache( + # name, + # f"results/llama2_7B_random_split_sheared_sampling_fluency_85B_gate_load/{name}_gate_load.npy", + # f"results/llama2_7B_random_split_sheared_sampling_fluency_85B_gate_load/{name}", + # minmax=True, + # ) diff --git a/smoe/entrypoint/analysis/hidden_before_gate_vis.py b/smoe/entrypoint/analysis/hidden_before_gate_vis.py new file mode 100644 index 0000000..195e50c --- /dev/null +++ b/smoe/entrypoint/analysis/hidden_before_gate_vis.py @@ -0,0 +1,436 @@ +import warnings +from pathlib import Path +from types import MethodType + +import matplotlib.pyplot as plt +import numpy as np +import torch +from accelerate import Accelerator +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +from torch.utils.data import DataLoader +from tqdm import tqdm, trange + +from smoe.data.collate_fn import fault_tolerance_data_collator +from smoe.data.streaming import CachedJsonlDataset +from smoe.models.llama_moe import LlamaMoEForCausalLM +from smoe.models.llama_moe.modeling_llama_moe import ( + LlamaMoEDecoderLayer, + MoEDecoderLayerOutput, + MoEMlpOutput, +) +from smoe.modules.moe.moe_gates import TopKBalancedNoisyGate + +eval_path_map = { + "en_wikipedia": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_wikipedia.jsonl", + "github": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/github.jsonl", + "en_stack": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_stack.jsonl", + "en_cc": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_cc.jsonl", + "en_c4": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_c4.jsonl", + "en_book": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_book.jsonl", + "en_arxiv": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/en_arxiv.jsonl", + "arc_challenge": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/arc_challenge.jsonl", + "gsm8k": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/gsm8k.jsonl", + "hellaswag": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/hellaswag.jsonl", + "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", +} +hidden_list = [] + + +def hidden_recording_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + output_attentions=False, + use_cache=False, +) -> MoEDecoderLayerOutput: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_list.append(hidden_states.detach().cpu()) + + mlp_outs: MoEMlpOutput = self.mlp(hidden_states) + hidden_states = residual + mlp_outs.hidden_states + + outputs = ( + hidden_states, + mlp_outs.balance_loss, + mlp_outs.num_dropped_tokens, + mlp_outs.gate_load, + mlp_outs.gate_importance, + ) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + for i, _o in enumerate(outputs): + if not isinstance(_o, torch.Tensor): + raise RuntimeError( + f"outputs[{i}]({type(_o)}) should be torch.Tensor to support grad ckpt" + ) + + return outputs + + +softmax_list = [] + + +def gate_recording_forward(self, x): + """先计算所有专家的权重值""" + logits_gate = self.gate_network(x) # gate计算出的权重 + logits = logits_gate # 最终权重,shape(batch_size, num_experts) + + """选出前k个权重,并计算各个专家的分数scores""" + top_logits, top_indices = logits.topk( + min(self.num_selects + 1, self.num_experts), dim=1 + ) # 选择并排序前k+1个权重 + top_k_logits = top_logits[:, : self.num_selects] + top_k_indices = top_indices[:, : self.num_selects] + top_k_scores = self.softmax(top_k_logits) + + """计算importance""" + zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device) + scores_filtered = zeros.scatter( + dim=1, index=top_k_indices, src=top_k_scores + ) # shape(batch_size, num_experts) + softmax_list.append(scores_filtered.detach().cpu()) + importance = scores_filtered.sum(0) # shape(num_experts) + + """计算load""" + load = (scores_filtered > 0).sum(0) + + """计算balance loss""" + if self.use_balance: + balance_loss = self.cv_squared(importance) + self.cv_squared(load) + balance_loss *= self.balance_loss_weight + else: + balance_loss = torch.tensor(-100.0, device=x.device) + + return { + "topK_indices": top_k_indices, + "topK_scores": top_k_scores, + "balance_loss": balance_loss, + "load": load, + "importance": importance, + } + + +def plot_2d_distribution(embs, labels, save_path, title: str = None): + fig = plt.figure() + ax = fig.add_subplot(111) + for emb, label in zip(embs, labels): + ax.scatter(emb[:, 0], emb[:, 1], alpha=0.2, label=label, s=2) + if title is not None: + ax.set_title(title) + plt.legend() + plt.tight_layout() + plt.savefig(save_path) + + +def tsne_for_one_layer(data_list, save_path, labels, title: str = None): + # data_list: (num datasets, num tokens, hidden dim) + tsne = TSNE(n_components=2, verbose=1) + min_num = min([len(data) for data in data_list]) + + embs = [] + for data in data_list: + emb = tsne.fit_transform(data[:min_num]) + embs.append(emb) + plot_2d_distribution(embs, labels, save_path, title=title) + + +def tsne_for_layers(data_list, save_dir, labels): + num_layers = len(data_list[0]) + for layer_idx in trange(num_layers, desc="Making scatter plots"): + tsne_for_one_layer( + [d[layer_idx] for d in data_list], + f"{save_dir}/tsne_L{layer_idx}.png", + labels, + title=f"t-SNE Layer {layer_idx}", + ) + + +def pca_for_one_layer(data_list, save_path, labels, title: str = None): + # data_list: (num datasets, num tokens, hidden dim) + pca = PCA(n_components=2) + min_num = min([len(data) for data in data_list]) + + embs = [] + for data in data_list: + emb = pca.fit_transform(data[:min_num]) + embs.append(emb) + plot_2d_distribution(embs, labels, save_path, title=title) + + +def pca_for_layers(data_list, save_dir, labels): + num_layers = len(data_list[0]) + for layer_idx in trange(num_layers, "Making scatter plots"): + pca_for_one_layer( + [d[layer_idx] for d in data_list], + f"{save_dir}/pca_L{layer_idx}.png", + labels, + title=f"PCA Layer {layer_idx}", + ) + + +@torch.no_grad() +def main( + model_dir: str, result_dir: str, eval_datanames: list[str], load_cache: bool = False +): + name2hidden = {} + name2softmax = {} + if load_cache: + for name in eval_datanames: + name2hidden[name] = np.load(f"{result_dir}/{name}_hidden.npy") + name2softmax[name] = np.load(f"{result_dir}/{name}_softmax.npy") + else: + global hidden_list + global softmax_list + bsz = 8 + result_dir = Path(result_dir) + result_dir.mkdir(exist_ok=True, parents=True) + + accel = Accelerator() + model = LlamaMoEForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + for module in model.modules(): + if isinstance(module, LlamaMoEDecoderLayer): + module.forward = MethodType(hidden_recording_forward, module) + if isinstance(module, TopKBalancedNoisyGate): + module.forward = MethodType(gate_recording_forward, module) + + model.eval() + model = accel.prepare_model(model) + + eval_dataset = { + name: CachedJsonlDataset(eval_path_map[name], seed=1227, block_size=4096) + for name in eval_datanames + } + for name, eval_dataset in eval_dataset.items(): + hidden_list = [] + tot_hidden = [] + softmax_list = [] + tot_softmax = [] + loader = DataLoader( + eval_dataset, batch_size=bsz, collate_fn=fault_tolerance_data_collator + ) + loader = accel.prepare_data_loader(loader) + if name == "en_book": + num_batch = 20 + else: + num_batch = 9999999999999999 + for batch_idx, batch in enumerate(tqdm(loader, desc=name)): + if batch_idx >= num_batch: + break + model(**batch, output_attentions=False, use_cache=False) + _tmp_batch_hidden = torch.stack(hidden_list, dim=0) + # (num layers, num tokens, hidden dim) + _tmp_batch_hidden = _tmp_batch_hidden.reshape( + len(hidden_list), -1, _tmp_batch_hidden.shape[-1] + ) + tot_hidden.append(_tmp_batch_hidden.detach().cpu().float().numpy()) + _tmp_batch_softmax = torch.stack(softmax_list, dim=0) + # (num layers, num tokens, num experts) + _tmp_batch_softmax = _tmp_batch_softmax.reshape( + len(softmax_list), -1, _tmp_batch_softmax.shape[-1] + ) + tot_softmax.append(_tmp_batch_softmax.detach().cpu().float().numpy()) + hidden_list = [] + softmax_list = [] + # (num layers, num tokens across all batches, hidden dim) + tot_hidden = np.concatenate(tot_hidden, axis=1) + # (num layers, num tokens across all batches, expert num) + tot_softmax = np.concatenate(tot_softmax, axis=1) + np.save(result_dir / f"{name}_hidden.npy", tot_hidden) + np.save(result_dir / f"{name}_softmax.npy", tot_softmax) + name2hidden[name] = tot_hidden + name2softmax[name] = tot_softmax + + # data_list = [] + # labels = [] + # for name, hidden in name2hidden.items(): + # data_list.append(hidden) + # labels.append(name) + # tsne_for_layers(data_list, result_dir, labels=labels) + # pca_for_layers(data_list, result_dir, labels=labels) + + +def heatmap(arr: np.ndarray, save_path: str, title: str, vmin: float, vmax: float): + shape = arr.shape + fig = plt.figure() + ax = fig.add_subplot(111) + im = ax.imshow(arr, cmap="OrRd", interpolation="nearest", vmin=vmin, vmax=vmax) + for row in range(shape[0]): + for col in range(shape[1]): + ax.text( + col, + row, + f"{arr[row, col]:.4f}", + ha="center", + va="center", + color="black", + ) + ax.set_axis_off() + ax.set_title(title) + fig.colorbar(im) + fig.tight_layout() + fig.savefig(save_path) + + +def dual_heatmap(arr1, arr2, save_path, layer_idx: int): + shape = arr1.shape + assert arr1.shape == arr2.shape + vmin = min(arr1.min(), arr2.min()) + vmax = max(arr2.max(), arr2.max()) + fig = plt.figure() + ax1 = fig.add_subplot(121) + im1 = ax1.imshow(arr1, cmap="OrRd", interpolation="nearest", vmin=vmin, vmax=vmax) + for row in range(shape[0]): + for col in range(shape[1]): + ax1.text( + col, + row, + f"{arr1[row, col]:.4f}", + ha="center", + va="center", + color="black", + ) + ax1.set_title("GSM8K") + ax2 = fig.add_subplot(122) + im2 = ax2.imshow(arr2, cmap="OrRd", interpolation="nearest", vmin=vmin, vmax=vmax) + for row in range(shape[0]): + for col in range(shape[1]): + ax2.text( + col, + row, + f"{arr2[row, col]:.4f}", + ha="center", + va="center", + color="black", + ) + ax2.set_title("MMLU") + ax1.set_axis_off() + ax2.set_axis_off() + fig.suptitle(f"Mean Routing Prob Layer {layer_idx}") + fig.tight_layout() + + fig.savefig(save_path) + + +def dual_hist( + arr1, arr2, save_path, layer_idx: int, xlim: list = None, ylim: list = None +): + fig = plt.figure() + ax = fig.add_subplot(111) + ax.hist(arr1.flatten(), bins=100, label="GSM8K", alpha=0.5) + ax.hist(arr2.flatten(), bins=100, label="MMLU", alpha=0.5) + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + ax.legend() + ax.set_title(f"Mean Routing Prob Layer {layer_idx}") + fig.tight_layout() + + fig.savefig(save_path) + + +def softmax_vis(name, cache_filepath, save_dir, vmin, vmax): + # (num layers, num tokens across all batches, expert num) + Path(save_dir).mkdir(exist_ok=True, parents=True) + vals = np.load(cache_filepath) + for layer_idx, layer_vals in enumerate(vals): + val = layer_vals.mean(axis=0) + val = val.reshape(4, 4) + heatmap( + val, + f"{save_dir}/softmax_L{layer_idx}.png", + f"{name} Routing Mean Prob Layer {layer_idx}", + vmin, + vmax, + ) + + +if __name__ == "__main__": + # w/ fluency filtering, 90b + model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2326233/checkpoint-6120" + result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_90B_hidden_dist_vis" + + Path(result_dir).mkdir(exist_ok=True, parents=True) + + # ---- hidden hist ---- + Path(result_dir).joinpath("hidden_hist").mkdir(exist_ok=True, parents=True) + vals1 = np.load(f"{result_dir}/gsm8k_hidden.npy") + vals2 = np.load(f"{result_dir}/mmlu_hidden.npy") + for layer_idx, (layer_vals1, layer_vals2) in enumerate(zip(vals1, vals2)): + num_tokens1 = layer_vals1.shape[0] + num_tokens2 = layer_vals2.shape[0] + limit = min(num_tokens1, num_tokens2) + layer_vals1 = layer_vals1[:limit] + layer_vals2 = layer_vals2[:limit] + assert layer_vals1.shape == layer_vals2.shape + + dual_hist( + layer_vals1, + layer_vals2, + f"{result_dir}/hidden_hist/dual_hidden_L{layer_idx}.png", + layer_idx, + xlim=(-2, 2), + ) + + # # ---- gate softmax values ---- + # Path(result_dir).joinpath("dual").mkdir(exist_ok=True, parents=True) + # Path(result_dir).joinpath("dual_hist").mkdir(exist_ok=True, parents=True) + # vals1 = np.load(f"{result_dir}/gsm8k_softmax.npy") + # vals2 = np.load(f"{result_dir}/mmlu_softmax.npy") + # for layer_idx, (layer_vals1, layer_vals2) in enumerate(zip(vals1, vals2)): + # num_tokens1 = layer_vals1.shape[0] + # num_tokens2 = layer_vals2.shape[0] + # limit = min(num_tokens1, num_tokens2) + # layer_vals1 = layer_vals1[:limit] + # layer_vals2 = layer_vals2[:limit] + # assert layer_vals1.shape == layer_vals2.shape + + # val1 = layer_vals1.mean(axis=0) + # val2 = layer_vals2.mean(axis=0) + # val1 = val1.reshape(4, 4) + # val2 = val2.reshape(4, 4) + # dual_heatmap( + # val1, + # val2, + # f"{result_dir}/dual/dual_softmax_L{layer_idx}.png", + # layer_idx, + # ) + + # val1_ind = np.argsort(layer_vals1, axis=1) + # top4_ind = val1_ind[:, -4:] + # val1 = np.take_along_axis(layer_vals1, top4_ind, axis=1) + # val2_ind = np.argsort(layer_vals2, axis=1) + # top4_ind = val2_ind[:, -4:] + # val2 = np.take_along_axis(layer_vals2, top4_ind, axis=1) + # dual_hist( + # val1, + # val2, + # f"{result_dir}/dual_hist/dual_softmax_L{layer_idx}.png", + # layer_idx, + # ) diff --git a/smoe/utils/notification.py b/smoe/utils/notification.py index d0cca39..050e8d7 100644 --- a/smoe/utils/notification.py +++ b/smoe/utils/notification.py @@ -87,7 +87,7 @@ def wechat_sender( msg_template = { "msgtype": "text", "text": { - "content": f"{msg_prefix}", + "content": "", "mentioned_list": user_mentions, "mentioned_mobile_list": user_mentions_mobile, }, @@ -120,7 +120,9 @@ def wrapper_sender(*args, **kwargs): "Starting date: %s" % start_time.strftime(DATE_FORMAT), ] - msg_template["text"]["content"] = "\n".join(contents) + msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( + contents + ) logger.info(f"{json.dumps(msg_template, ensure_ascii=False)}") if webhook_url: requests.post(webhook_url, json=msg_template) @@ -150,7 +152,9 @@ def wrapper_sender(*args, **kwargs): % "ERROR - Couldn't str the returned value." ) - msg_template["text"]["content"] = "\n".join(contents) + msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( + contents + ) logger.info(f"{json.dumps(msg_template, ensure_ascii=False)}") if webhook_url: requests.post(webhook_url, json=msg_template) @@ -174,7 +178,9 @@ def wrapper_sender(*args, **kwargs): "%s" % traceback.format_exc(), ] - msg_template["text"]["content"] = "\n".join(contents) + msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( + contents + ) logger.info(f"{json.dumps(msg_template, ensure_ascii=False)}") if webhook_url: requests.post(webhook_url, json=msg_template) From 53217dd9294f6655b16c5ae4e43c7443b12d3304 Mon Sep 17 00:00:00 2001 From: zhutong Date: Wed, 13 Dec 2023 23:46:26 +0800 Subject: [PATCH 04/12] update gate load vis and docs --- .gitattributes | 1 + README.md | 130 ++++++-- docs/Contribution.md | 11 + docs/Installation.md | 4 +- docs/continual_pretraining/README.md | 55 ++-- docs/imgs/title-favicon.png | 3 + example.py | 17 ++ ...2gpus_sheared_llama_portion_fluency_sf4.sh | 171 +++++++++++ ...2gpus_sheared_llama_portion_fluency_sf8.sh | 172 +++++++++++ smoe/entrypoint/analysis/gate_load_vis.py | 282 +++++++++++++----- smoe/utils/param_estimation.py | 72 ++++- smoe/utils/tokenize.py | 2 +- smoe/utils/visualization/visualize.py | 6 +- 13 files changed, 810 insertions(+), 116 deletions(-) create mode 100644 .gitattributes create mode 100644 docs/Contribution.md create mode 100644 docs/imgs/title-favicon.png create mode 100644 example.py create mode 100644 scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh create mode 100644 scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..8dc584d --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +docs/imgs/title-favicon.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 261cc03..4445f67 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,127 @@ -# train-moe +
+

LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training

+ LLaMA-MoE favicon
+ 📢 A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!! + +
-[[Installation Guide]](docs/Installation.md) | [[MoEfication Docs]](docs/moefication/README.md) | [[Continual Pre-training Docs]](docs/continual_pretraining/README.md) +

🎉 Introduction

-## 🌴 Dependencies +LLaMA-MoE is a series of Mixture-of-Expert (MoE) models based on [LLaMA](https://github.com/facebookresearch/llama). +We build LLaMA-MoE with the following two steps: +1. Partition LLaMA's FFNs into sparse experts and insert top-K gate for each layer of experts. +2. Continually pre-train the initialized MoE model with an optimized data sampling weights from [Sheared LLaMA](https://arxiv.org/abs/2310.06694) and filtered datasets from [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama). -- Python==3.11.4 - - Packages: please check `requirements.txt` (NOTE: `flash-attn` must be properly installed by following [their instructions](https://github.com/Dao-AILab/flash-attention)) -## 🚀 QuickStart +| Model | \#Activated Experts | \#Experts | \#Activated Params | \#Total Prams | Links | +| :----------------- | :-----------------: | :-------: | :----------------: | :-----------: | :----------------------------------------------------------------------------------------------: | +| OPT-2.7B | - | - | 2.7B | 2.7B | ([Zhang et al., 2022](https://huggingface.co/facebook/opt-2.7b)) | +| Pythia-2.8B | - | - | 2.8B | 2.8B | ([Biderman et al., 2023](https://huggingface.co/EleutherAI/pythia-2.8b)) | +| INCITE-BASE-3B | - | - | 2.8B | 2.8B | ([Together Computer, 2023](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1)) | +| Open-LLaMA-3B-v2 | - | - | 3.4B | 3.4B | ([Geng et al., 2023](https://huggingface.co/openlm-research/open_llama_3b_v2)) | +| Sheared-LLaMA-2.7B | - | - | 2.7B | 2.7B | ([Xia et al., 2023](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B)) | +| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | 6.7B | [[HF Weights]](https://huggingface.co) | +| **LLaMA-MoE-3.5B** | 4 | 16 | 3.5B | 6.7B | [[HF Weights]](https://huggingface.co) | + + + +| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | +| :----------------- | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :-------: | +| OPT-2.7B | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 55.9 | 10.7 | 25.8 | 49.6 | +| Pythia-2.8B | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 54.4 | 8.6 | 26.8 | 50.6 | +| INCITE-BASE-3B | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 55.6 | 15.2 | 27.2 | 52.8 | +| Open-LLaMA-3B-v2 | **88.0** | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 59.5 | 16.0 | 26.8 | 54.9 | +| Sheared-LLaMA-2.7B | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 59.7 | 17.7 | **27.3** | 55.6 | +| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 59.7 | 17.0 | 26.8 | 54.8 | +| **LLaMA-MoE-3.5B** | 87.6 | **77.9** | **65.5** | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **63.2** | **20.3** | 26.8 | **57.2 ** | + + +

🚀 QuickStart

+ +```python +import torch +from transformers import AutoTokenizer +from smoe.models.llama_moe import LlamaMoEForCausalLM + + +model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2342244/checkpoint-13600/" +tokenizer = AutoTokenizer.from_pretrained(model_dir) +model = LlamaMoEForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) +model.to("cuda:0") + +input_text = "Suzhou is famous of" +inputs = tokenizer(input_text, return_tensors="pt") +inputs = inputs.to("cuda:0") + +pred = model.generate(**inputs, max_length=50, temperature=0.0) +print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) +# Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three +``` + +

🚧 Expert Initialization

+ +- Neuron-Independent + - IndependentRandom: `bash ./scripts/moefication/split/run_split_random.sh` + - IndependentClustering: `bash ./scripts/moefication/split/run_split_clustering.sh` +- Neuron-Sharing + - SharingInner: `bash ./scripts/moefication/split/run_split_gradient.sh` + - SharingInter: `bash ./scripts/moefication/split/run_split_gradient_residual.sh` + +For more information, please refer to [Expert Initialization docs](docs/moefication/README.md). + +

🚅 Continual Pre-training

+ ### Tokenization -- RedPajama: `bash scripts/tokenize/redpajama.sh` (Don't forget to change the folder paths.) +Download [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama) into `/path_to_data` and put data from different domains into separate folders: + - `/path_to_data/en_arxiv` + - `/path_to_data/en_book` + - `/path_to_data/en_c4` + - `/path_to_data/en_cc` + - `/path_to_data/en_stack` + - `/path_to_data/en_wikipedia` + - `/path_to_data/github` + +Each file should be end with `*.jsonl` and each line looks like: +``` +{"id": "id-info", "content": "raw text to be tokenized"} +``` + +Run the following command to tokenize the data in each folder: + +```bash +python -m smoe.utils.tokenize \ + -f jsonl \ + -t /path_to_tokenizer \ + -i /path_to_data/en_arxiv \ + -o /path_to_data_tokenized/en_arxiv +``` ### Continual Pre-training (CPT) -**NOTICE:** Please create `logs/` folder manually: `mkdir -p logs` +- **NOTICE:** Please create `logs/` folder manually: `mkdir -p logs` +- To run the continual pre-training, please check the [CPT docs](docs/continual_pretraining/README.md). -- LLaMA MoEfication LoRA: `sbatch scripts/cpt/lora.sh` -- LLaMA MoEfication Full-Parameter: `sbatch scripts/cpt/fpt.sh` +

💎 Evaluation

-## 🤝 Contribution +- For evalution on Natural Questions (NQ), please refer to [opencompass](https://github.com/Spico197/opencompass/tree/main). +- For other tasks, please refer to [lm-eval-harness](https://github.com/spico197/smoe-eval). -- Make sure the Python version `>=3.10` (a strict version contraint for better type hinting) +

📑 Citation

-```bash -$ conda install git # upgrade git -$ git clone git@github.com:pjlab-sys4nlp/train-moe.git -$ cd train-moe -$ pip install -e .[dev] -$ pre-commit install +```bibtex +@article{llama-moe-2023, + title={LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training}, + author={LLaMA-MoE Team}, + journal={arXiv preprint arXiv:}, + url={https://arxiv.org/abs/}, + year={2023} +} ``` + +
+

LLaMA-MoE Team w/ ❤️

diff --git a/docs/Contribution.md b/docs/Contribution.md new file mode 100644 index 0000000..885a76e --- /dev/null +++ b/docs/Contribution.md @@ -0,0 +1,11 @@ +# 🤝 Contribution + +- Make sure the Python version `>=3.10` (a strict version contraint for better type hinting) + +```bash +$ conda install git # upgrade git +$ git clone git@github.com:pjlab-sys4nlp/llama-moe.git +$ cd llama-moe +$ pip install -e .[dev] +$ pre-commit install +``` diff --git a/docs/Installation.md b/docs/Installation.md index 8dae4e9..06dc814 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -1,7 +1,7 @@ # 🌴 Installation 1. Prepare conda environment: `conda create -n smoe python=3.11` (If your environment name is not `smoe`, you may need to change environment in launching scripts) -2. Add environment variables in `~/.bashrc` (`gcc` is set to newer version for installing `flash-attn`): +2. Add correct environment variables in `~/.bashrc` (`gcc` is set to newer version for installing `flash-attn`). e.g.: ```bash export PATH=/mnt/petrelfs/share/cuda-11.8/bin:$PATH export LD_LIBRARY_PATH=/mnt/petrelfs/share/cuda-11.8/lib64:$LD_LIBRARY_PATH @@ -11,7 +11,7 @@ 3. Take the variables into effect: `source ~/.bashrc` 4. Install PyTorch (CUDA-11.8): `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118` 5. Install dependencies: `pip install -r requirements.txt` -6. Install `flash-attn`: `pip install flash-attn==2.0.1 --no-build-isolation` +6. Install `flash-attn`: `pip install flash-attn==2.0.1 --no-build-isolation`. You may need to follow the [flash-attn installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) to avoid some errors. 7. Install the latest Git: `conda install git` 8. Clone the repo: `git clone git@github.com:pjlab-sys4nlp/train-moe.git` (If you don't setup the ssh key to GitHub, you may not able to clone through ssh. Check the [docs](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account) about it.) 9. Change current directory: `cd train-moe` diff --git a/docs/continual_pretraining/README.md b/docs/continual_pretraining/README.md index 56a5448..6e5e441 100644 --- a/docs/continual_pretraining/README.md +++ b/docs/continual_pretraining/README.md @@ -1,6 +1,38 @@ # 🚅 Training Guide -## ⚙️ Configuration Instructions +## 🗞️ Executive Scripts + +| Description | Path | +| :------------------------ | :------------------------------------------------------------------------------------- | +| LLaMA-MoE 2/16 Experts | `scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh` | +| LLaMA-MoE 4/16 Experts | `scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh` | +| DynamicSheared | `scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh` | + +## 🌴 Other Arguments in Executive Scripts + +| Argument Name | Description | +| :------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--dynamic_data_selection` | For different dynamic data sampling strategies, choose one from: `sheared_llama` or `none` (static). Default: `none` | +| `--moe_calculator_score_scale_factor` | Scale factor to multiply after hidden states are procesed by experts. Should be $\frac{\text{\#total experts}}{\text{\#selected}}$. Default: `4.0` | +| `--num_selects` | The number of selected experts. Default: `4` | +| `--gate_balance_loss_weight` | The weight of the balance loss for the gate. Default: `1e-2` | + +## 📋 Checklist before Starting an Experiment + +- [ ] balance loss weight +- [ ] scale factor +- [ ] learning rate +- [ ] warmup steps +- [ ] evaluation steps +- [ ] logging steps +- [ ] global batch size +- [ ] number of selected experts +- [ ] pretrained model +- [ ] data path +- [ ] GPUs +- [ ] comment + +## ⚙️ Configuration Instructions for Slurm Users For `scripts/cpt/lora.sh` and `scripts/cpt/fpt.sh` files, we could run an experiment via `sbatch`. e.g. `sbatch scripts/cpt/lora.sh` . @@ -36,12 +68,6 @@ llama1-7b 16 select 4: 3.49b params llama1-13b total params: 13,015,864,320 - total mlp params: 8,493,465,600 -| total experts | selected | dropped params | added gate params | total params | -| ------------: | -------: | -------------: | ----------------: | ------------: | -| 16 | 8 | 4,246,732,800 | 3,287,040 | 8,772,418,560 | -| 16 | 4 | 6,370,099,200 | 3,287,040 | 6,649,052,160 | -| 16 | 2 | 7,431,782,400 | 3,287,040 | 5,587,368,960 | - ## 🧮 Estimation of Training Speed and Tokens For convenient estimation of the model training speed, we provide some useful information at the very beginning of log files: @@ -77,18 +103,3 @@ Here, the `short_name` is an abbreviation for your task, and the port number cou ```bash $ tensorboard --logdir_spec moe_from_scratch:outputs/cpt-llama-moe-scratch-lora-bs16-1476932/runs/Jul26_21-53-42_SH-IDCA1404-10-140-54-121,moe_lora:outputs/cpt-llama-lora-bs16-1476918/runs/Jul26_21-31-09_SH-IDCA1404-10-140-54-122 --port 8001 ``` - -## 📋 Checklist before Starting an Experiment - -- [ ] balance loss weight -- [ ] scale factor -- [ ] learning rate -- [ ] warmup steps -- [ ] evaluation steps -- [ ] logging steps -- [ ] global batch size -- [ ] number of selected experts -- [ ] pretrained model -- [ ] data path -- [ ] GPUs -- [ ] comment diff --git a/docs/imgs/title-favicon.png b/docs/imgs/title-favicon.png new file mode 100644 index 0000000..b1f900c --- /dev/null +++ b/docs/imgs/title-favicon.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:656e5f3de4440b469d9b7bb928a14872dfb69329b7d539bce620cdef782d804c +size 1167538 diff --git a/example.py b/example.py new file mode 100644 index 0000000..b32c8c8 --- /dev/null +++ b/example.py @@ -0,0 +1,17 @@ +import torch +from transformers import AutoTokenizer + +from smoe.models.llama_moe import LlamaMoEForCausalLM + +model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2342244/checkpoint-13600/" +tokenizer = AutoTokenizer.from_pretrained(model_dir) +model = LlamaMoEForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) +model.to("cuda:0") + +input_text = "Suzhou is famous of" +inputs = tokenizer(input_text, return_tensors="pt") +inputs = inputs.to("cuda:0") + +pred = model.generate(**inputs, max_length=50, temperature=0.0) +print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) +# Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three diff --git a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh new file mode 100644 index 0000000..ff4c760 --- /dev/null +++ b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh @@ -0,0 +1,171 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_split_112gpus_16_2_scale_factor_4 +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 2/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=2 + scale_factor=4.0 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --num_selects ${num_selects} \ + --moe_calculator_score_scale_factor ${scale_factor} \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" +} diff --git a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh new file mode 100644 index 0000000..51a90fd --- /dev/null +++ b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh @@ -0,0 +1,172 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_split_112gpus_16_2_scale_factor_8 +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 2/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=2 + scale_factor=8.0 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --resume_from_checkpoint "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2340407/checkpoint-1020/" \ + --prob_map "sheared_llama" \ + --num_selects ${num_selects} \ + --moe_calculator_score_scale_factor ${scale_factor} \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" +} diff --git a/smoe/entrypoint/analysis/gate_load_vis.py b/smoe/entrypoint/analysis/gate_load_vis.py index 692ea08..cbabd1c 100644 --- a/smoe/entrypoint/analysis/gate_load_vis.py +++ b/smoe/entrypoint/analysis/gate_load_vis.py @@ -12,6 +12,20 @@ from smoe.models.llama_moe import LlamaMoEForCausalLM from smoe.utils.visualization.visualize import visualize_expert_load_heatmap +NAME_MAP = { + "en_wikipedia": "Wikipedia", + "github": "GitHub", + "en_arxiv": "arXiv", + "en_book": "Book", + "en_cc": "CommonCrawl", + "en_c4": "C4", + "en_stack": "StackExchange", + "arc_challenge": "ARC-c", + "gsm8k": "GSM-8K", + "hellaswag": "HellaSwag", + "mmlu": "MMLU", +} + @torch.no_grad() def main( @@ -19,6 +33,7 @@ def main( result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/", ): bsz = 4 + num_batch = 1 # 128 # model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-16Select4-688Neurons-Share" # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage1-2090437/checkpoint-4000" # model_dir = "/mnt/petrelfs/zhutong/smoe/outputs/cpt-7b-4_16_noisygate-gate_stage2-2105807/checkpoint-4000" @@ -35,6 +50,19 @@ def main( "hellaswag": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/hellaswag.jsonl", "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", # 23720 tokens } + # eval_path_map = { + # "en_wikipedia": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_wikipedia/part-000838-79b0b564.jsonl", + # "github": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/github/part-000113-79b0b564.jsonl", + # "en_stack": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_stack/part-001298-79b0b564.jsonl", + # "en_cc": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_cc/part-000113-79b0b564.jsonl", + # "en_c4": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_c4/part-001298-79b0b564.jsonl", + # "en_book": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_book/part-002145-79b0b564.jsonl", + # "en_arxiv": "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_arxiv/part-000113-79b0b564.jsonl", + # "arc_challenge": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/arc_challenge.jsonl", + # "gsm8k": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/gsm8k.jsonl", # 37998 tokens + # "hellaswag": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/hellaswag.jsonl", + # "mmlu": "/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/mmlu.jsonl", # 23720 tokens + # } # result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_gradient_share_gate_load/stage1_trained_more/" result_dir = Path(result_dir) @@ -59,11 +87,10 @@ def main( eval_dataset, batch_size=bsz, collate_fn=fault_tolerance_data_collator ) loader = accel.prepare_data_loader(loader) - if name == "en_book": - num_batch = 20 - else: - num_batch = 9999999999999999 - num_batch = 1 + # if name == "en_book": + # num_batch = 20 + # else: + # num_batch = 9999999999999999 for batch_idx, batch in enumerate(tqdm(loader, desc=name)): if batch_idx >= num_batch: break @@ -103,7 +130,7 @@ def heatmap( text = ax.text( j, i, - f"{arr[i, j]:.1%}", + f"{arr[i, j]:.3}", ha="center", va="center", color="black", @@ -111,8 +138,8 @@ def heatmap( ) ax.set_xticks(range(len(xlabels))) ax.set_yticks(range(len(ylabels))) - ax.set_xticklabels(xlabels, rotation=45, ha="right") - ax.set_yticklabels(ylabels) + ax.set_xticklabels([NAME_MAP[n] for n in xlabels], rotation=45, ha="right") + ax.set_yticklabels([NAME_MAP[n] for n in ylabels]) ax.set_title(title) fig.colorbar(im) fig.tight_layout() @@ -124,50 +151,55 @@ def calc_sim( gate_load_folder="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_baseline_gate_load/", layer_idx=0, plot=True, + plot_type="train-train", # or dev-train ): # title = "SlimPajama" # sim_pairs = [["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], ["wiki", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"]] - # title = "Dev vs. SlimPajama" - # sim_pairs = [ - # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - # ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], - # ] - # title = "Dev vs. Dev" - # sim_pairs = [ - # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], - # ] - # title = "test" - # sim_pairs = [["wiki", "github"], ["wiki", "github"]] - title = f"Routing Similarity Layer {layer_idx}" - sim_pairs = [ - [ - "arc_challenge", - "gsm8k", - "hellaswag", - "mmlu", - "en_wikipedia", - "github", - "en_stack", - "en_cc", - "en_c4", - "en_book", - "en_arxiv", - ], - [ - "arc_challenge", - "gsm8k", - "hellaswag", - "mmlu", - "en_wikipedia", - "github", - "en_stack", - "en_cc", - "en_c4", - "en_book", - "en_arxiv", - ], - ] + + if plot_type == "train-train": + title = "Train vs. Train" + sim_pairs = [ + [ + "en_wikipedia", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + "github", + "en_stack", + ], + [ + "en_wikipedia", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + "github", + "en_stack", + ], + ] + elif plot_type == "dev-dev": + title = "Dev vs. Dev" + sim_pairs = [ + ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], + ] + elif plot_type == "dev-train": + title = "Dev vs. Train" + sim_pairs = [ + ["hellaswag", "arc_challenge", "mmlu", "gsm8k"], + [ + "en_wikipedia", + "en_cc", + "en_c4", + "en_book", + "en_arxiv", + "github", + "en_stack", + ], + ] + else: + raise ValueError folder = Path(gate_load_folder) name2arr = {} @@ -175,8 +207,14 @@ def calc_sim( for dtype in folder.glob("*" + suffix): name = dtype.name[: -len(suffix)] arr = np.load(folder / f"{name}{suffix}") - # min-max - name2arr[name] = arr[layer_idx] / arr[layer_idx].max() + # name2arr[name] = arr[layer_idx] + # name2arr[name] = arr[layer_idx] / arr[layer_idx].sum() + name2arr[name] = (arr[layer_idx] - arr[layer_idx].min()) / ( + arr[layer_idx].max() - arr[layer_idx].min() + ) + + # # min-max + # name2arr[name] = arr[layer_idx] / arr[layer_idx].max() # # softmax # layer_arr = arr[layer_idx] # e_x = np.exp(layer_arr - layer_arr.max()) @@ -187,31 +225,40 @@ def calc_sim( t1_load = name2arr[type1] for t2_idx, type2 in enumerate(sim_pairs[1]): t2_load = name2arr[type2] - _sim = np.dot(t1_load, t2_load) / ( - np.linalg.norm(t1_load) * np.linalg.norm(t2_load) - ) + # _sim = np.dot(t1_load, t2_load) / ( + # np.linalg.norm(t1_load, 2) * np.linalg.norm(t2_load, 2) + # ) # _sim = 1.0 - np.linalg.norm(t1_load - t2_load, 2) + # _sim = -np.linalg.norm(t1_load - t2_load, 2) + _sim = np.linalg.norm(t1_load - t2_load, 2) + # _sim = 1.0 - np.sqrt(np.power(t1_load - t2_load, 2).sum()) + # _sim = -np.sqrt(np.power(t1_load - t2_load, 2).sum()) sim_arr[t1_idx][t2_idx] = _sim if plot: heatmap( sim_arr, sim_pairs[1], sim_pairs[0], - str(folder / f"layer{layer_idx}" / f"cos_sim_{title}.png"), + str(folder / f"layer{layer_idx + 1}" / f"cos_sim_{title}.pdf"), title, ) return sim_arr -def gate_load_vis(): - model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2326233/checkpoint-5440" - result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_80B_gate_load/" - main( - # w/ fluency filtering, 85b - model_dir=model_dir, - result_dir=result_dir, - ) +def gate_load_vis( + model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2356022/checkpoint-13600/", + result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/", +): + # main( + # # w/ fluency filtering, 85b + # model_dir=model_dir, + # result_dir=result_dir, + # ) + # main( + # model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2356022/checkpoint-13600/", + # result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/", + # ) sim_arr_list = [] for layer_idx in range(32): @@ -227,7 +274,7 @@ def gate_load_vis(): # ["arc_challenge", "gsm8k", "hellaswag", "mmlu"], # ["en_wikipedia", "github", "en_stack", "en_cc", "en_c4", "en_book", "en_arxiv"], # ] - title = "Routing Similarity" + title = "Routing Differences" sim_pairs = [ [ "arc_challenge", @@ -260,7 +307,7 @@ def gate_load_vis(): sim_arr, sim_pairs[1], sim_pairs[0], - f"{result_dir}/cos_sim_avg.png", + f"{result_dir}/cos_sim_avg_{title}.pdf", title, ) @@ -274,7 +321,7 @@ def gate_load_vis_from_cache(name, cache_filepath, result_dir, minmax: bool = Fa for layer_idx in range(gate_load_sum.shape[0]): visualize_expert_load_heatmap( gate_load_sum[layer_idx], - layer_idx, + layer_idx + 1, name, shape=(4, 4), save_dir=str(result_dir), @@ -282,13 +329,69 @@ def gate_load_vis_from_cache(name, cache_filepath, result_dir, minmax: bool = Fa ) +def gate_load_var_trend(paths, output_figpath): + data_list = [] + var_list = [] + layer_list = [] + + tmp = np.load(paths[0]) + num_layers, num_experts = tmp.shape + + for path in paths: + data = np.load(path) + data_list.append(data) + for layer_idx in range(num_layers): + layer_list.append(layer_idx + 1) + loads = [] + for data in data_list: + loads.append(data[layer_idx].flatten()) + _var = np.var(np.stack(loads, axis=0), axis=0).sum() + var_list.append(_var) + + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(layer_list, var_list) + ax.set_xlabel("Layer") + ax.set_ylabel("Variance") + fig.savefig(output_figpath, dpi=320, bbox_inches="tight") + plt.close() + + if __name__ == "__main__": - main( - model_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2326233/checkpoint-6120", - result_dir="/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_90B_gate_load/", + """ + srun -p MoE -n1 -N1 --gres=gpu:1 python -m smoe.entrypoint.analysis.gate_load_vis + """ + model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/outputs/cpt-llama2_random_scale4_112gpus_dynamic_data-2356022/checkpoint-13600" + result_dir = "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load_more_tokens" + main(model_dir=model_dir, result_dir=result_dir) + for name in NAME_MAP.keys(): + title = NAME_MAP[name] + gate_load_vis_from_cache( + title, + f"{result_dir}/{name}_gate_load.npy", + f"{result_dir}/{name}", + minmax=False, + ) + calc_sim( + result_dir, + layer_idx=31, + plot_type="dev-train", + plot=True, + ) + calc_sim( + result_dir, + layer_idx=31, + plot_type="train-train", + plot=True, ) - # gate_load_vis() + # gate_load_vis(model_dir=model_dir, result_dir=result_dir) + + # calc_sim( + # "/mnt/petrelfs/zhutong/smoe/results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/", + # layer_idx=27, + # plot=True, + # ) # for name in ["gsm8k", "mmlu"]: # gate_load_vis_from_cache( @@ -297,3 +400,42 @@ def gate_load_vis_from_cache(name, cache_filepath, result_dir, minmax: bool = Fa # f"results/llama2_7B_random_split_sheared_sampling_fluency_85B_gate_load/{name}", # minmax=True, # ) + + # for name in NAME_MAP.keys(): + # title = NAME_MAP[name] + # gate_load_vis_from_cache( + # title, + # f"results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/{name}_gate_load.npy", + # f"results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/{name}", + # minmax=False, + # ) + + # gate_load_var_trend( + # [ + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_arxiv_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_book_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_c4_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_cc_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_stack_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_wikipedia_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/github_gate_load.npy", + # ], + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/var_trend.pdf", + # ) + + # filepaths = [ + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_arxiv_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_book_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_c4_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_cc_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_stack_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/en_wikipedia_gate_load.npy", + # "results/llama2_7B_random_split_sheared_sampling_fluency_200B_gate_load/github_gate_load.npy", + # ] + # mins = [] + # maxs = [] + # for path in filepaths: + # data = np.load(path) + # mins.append(data.min()) + # maxs.append(data.max()) + # print(min(mins), max(maxs)) diff --git a/smoe/utils/param_estimation.py b/smoe/utils/param_estimation.py index b3cfc60..3d4c21d 100644 --- a/smoe/utils/param_estimation.py +++ b/smoe/utils/param_estimation.py @@ -49,18 +49,82 @@ def estimate_moe_param( } +def normal_moe_param( + vocab_size, + hidden_size, + num_hidden_layers, + intermediate_size, + num_experts, + num_selects, + kv_attn_ratio: float = 1.0, +): + emb = vocab_size * hidden_size + lm_head = vocab_size * hidden_size + final_norm = hidden_size + + self_attn = ( + hidden_size * hidden_size * 2 + hidden_size * hidden_size * kv_attn_ratio * 2 + ) + mlp = hidden_size * intermediate_size * 3 + input_norm = hidden_size + post_attn_norm = hidden_size + + dense_one_layer = self_attn + mlp + input_norm + post_attn_norm + dense_mid = dense_one_layer * num_hidden_layers + dense_params = emb + lm_head + final_norm + dense_mid + + gate = hidden_size * num_selects + moe_one_layer = self_attn + mlp * num_experts + input_norm + post_attn_norm + gate + moe_one_layer_act = ( + self_attn + mlp * num_selects + input_norm + post_attn_norm + gate + ) + moe_mid = moe_one_layer * num_hidden_layers + moe_tot_params = emb + lm_head + final_norm + moe_mid + moe_act_mid = moe_one_layer_act * num_hidden_layers + moe_act_params = emb + lm_head + final_norm + moe_act_mid + + return { + "dense_params": dense_params, + "moe_tot_params": moe_tot_params, + "moe_act_params": moe_act_params, + "dense_mid": dense_mid, + "moe_mid": moe_mid, + "moe_act_mid": moe_act_mid, + } + + if __name__ == "__main__": - # 3B + # opt-2.7b: 2651596800 + opt = normal_moe_param(50272, 2560, 32, 10240, 1, 1) + print("opt-2.7b", opt) + + # pythia-2.8b: 2775208960 + pythia = normal_moe_param(50304, 2560, 32, 10240, 1, 1) + print("pythia-2.8b", pythia) + + # incite-base-3b: 2775864320 + incite = normal_moe_param(50432, 2560, 32, 10240, 1, 1) + print("incite", incite) + + # 3B: open-llama-3b-v2: 3426473600 res_3B = estimate_moe_param(32000, 3200, 26, 8640, 16, 4) print("3B", res_3B) # 7B res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 4) print("7B", res_7B) + res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 2) + print("7B 2/16", res_7B) + res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 1) + print("7B 1/16", res_7B) # 13B res_13B = estimate_moe_param(32000, 5120, 40, 13824, 16, 4) - print("13B", res_13B) + print("13B 4/16", res_13B) + res_13B = estimate_moe_param(32000, 5120, 40, 13824, 16, 2) + print("13B 2/16", res_13B) + res_13B = estimate_moe_param(32000, 5120, 40, 13824, 16, 1) + print("13B 1/16", res_13B) # 3B upcycling for num_experts in range(1, 9): @@ -90,3 +154,7 @@ def estimate_moe_param( print("7B half 24 layers", res_7B_half) res_7B_half = estimate_moe_param(32000, 4096, 16, 11008, 16, 1) print("7B half 16 layers 1/16", res_7B_half) + + # mixtral 7Bx8 + res_mixtral = normal_moe_param(32000, 4096, 32, 14336, 8, 2, kv_attn_ratio=0.25) + print("mixtral 7Bx8", res_mixtral) diff --git a/smoe/utils/tokenize.py b/smoe/utils/tokenize.py index 747ba8e..60a6e48 100644 --- a/smoe/utils/tokenize.py +++ b/smoe/utils/tokenize.py @@ -123,7 +123,7 @@ def _tokenization_func(examples): ) tokenized_ds.to_json(output_filepath, lines=True, num_proc=args.num_proc) - prepare_meta(output_filepath) + # prepare_meta(output_filepath) if input_path.is_dir(): input_files = list(input_path.glob(f"*.{args.format}")) diff --git a/smoe/utils/visualization/visualize.py b/smoe/utils/visualization/visualize.py index 6e5c033..8e0437d 100644 --- a/smoe/utils/visualization/visualize.py +++ b/smoe/utils/visualization/visualize.py @@ -300,6 +300,8 @@ def visualize_expert_load_heatmap( if save_dir_path.is_file(): raise ValueError(f"{save_dir} is a file, not a directory") save_dir_path.mkdir(exist_ok=True, parents=True) + # path = save_dir_path / Path(f"{dataset_name}_Layer{layer_idx}.pdf") + # print(layer_idx, path) path = save_dir_path / Path(f"{dataset_name}_Layer{layer_idx}.png") data = load_sum.reshape(*shape) @@ -308,6 +310,7 @@ def visualize_expert_load_heatmap( fig = plt.figure() ax = fig.add_subplot(111) im = ax.imshow(data, cmap=cmap, interpolation="nearest") + # im = ax.imshow(data, cmap=cmap, interpolation="nearest", vmin=3500, vmax=4500) for i in range(shape[0]): for j in range(shape[1]): @@ -319,7 +322,8 @@ def visualize_expert_load_heatmap( fig.tight_layout() if save_fig: fig.savefig(str(path), dpi=320, bbox_inches="tight") - compress_png_image(str(path), print_info=False) + if path.suffix == ".png": + compress_png_image(str(path), print_info=False) return fig From 9bb6ec6bea7858363d5977ac84554a5daf0dc6a0 Mon Sep 17 00:00:00 2001 From: zhutong Date: Fri, 15 Dec 2023 14:35:25 +0800 Subject: [PATCH 05/12] update 8_2 script and fp32 softmax gating --- ...s_8_2_sheared_llama_portion_fluency_sf4.sh | 171 ++++++++++++++++++ smoe/modules/moe/moe_gates.py | 3 +- smoe/utils/param_estimation.py | 2 + 3 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh diff --git a/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh new file mode 100644 index 0000000..177c55a --- /dev/null +++ b/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh @@ -0,0 +1,171 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_split_112gpus_8_2 +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_8_2/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_8_2/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 2/8, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-8Select2-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-no-ad-processed + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=2 + scale_factor=4.0 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_8_2" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --num_selects ${num_selects} \ + --moe_calculator_score_scale_factor ${scale_factor} \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" +} diff --git a/smoe/modules/moe/moe_gates.py b/smoe/modules/moe/moe_gates.py index 637840a..c1de523 100644 --- a/smoe/modules/moe/moe_gates.py +++ b/smoe/modules/moe/moe_gates.py @@ -295,7 +295,8 @@ def forward(self, x): top_logits, top_indices = logits.topk(min(self.num_selects + 1, self.num_experts), dim=1) # 选择并排序前k+1个权重 top_k_logits = top_logits[:, :self.num_selects] top_k_indices = top_indices[:, :self.num_selects] - top_k_scores = self.softmax(top_k_logits) if self.use_softmax else top_k_logits + top_k_scores = self.softmax(top_k_logits.to(torch.float32)) if self.use_softmax else top_k_logits + top_k_scores = top_k_scores.to(logits.dtype) """计算importance""" zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device) diff --git a/smoe/utils/param_estimation.py b/smoe/utils/param_estimation.py index 3d4c21d..7f4a7ed 100644 --- a/smoe/utils/param_estimation.py +++ b/smoe/utils/param_estimation.py @@ -117,6 +117,8 @@ def normal_moe_param( print("7B 2/16", res_7B) res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 1) print("7B 1/16", res_7B) + res_7B = estimate_moe_param(32000, 2560, 32, 11008, 8, 2) + print("7B-2560", res_7B) # 13B res_13B = estimate_moe_param(32000, 5120, 40, 13824, 16, 4) From 338ee9e3a2b1008de16de952271db12d47fc27c2 Mon Sep 17 00:00:00 2001 From: Tong Zhu Date: Sat, 16 Dec 2023 01:08:10 +0800 Subject: [PATCH 06/12] update mixtral support --- .vscode/launch.json | 9 +- requirements.txt | 1 + ...s_8_2_sheared_llama_portion_fluency_sf4.sh | 149 ++ scripts/tokenize/slimpajama_convert.sh | 52 + smoe/entrypoint/cpt/cpt_fpt.py | 5 + smoe/models/mixtral/__init__.py | 62 + smoe/models/mixtral/configuration_mixtral.py | 311 ++++ smoe/models/mixtral/modeling_mixtral.py | 1643 +++++++++++++++++ smoe/utils/cache_utils.py | 326 ++++ smoe/utils/modeling_attn_mask_utils.py | 472 +++++ smoe/utils/split_files.py | 56 + smoe/utils/tokenize.py | 56 +- tools/check_killed.py | 73 + tools/listen.py | 45 + tools/queue_submit.py | 144 ++ tools/scl_jobs.sh | 45 + 16 files changed, 3437 insertions(+), 12 deletions(-) create mode 100644 scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh create mode 100644 scripts/tokenize/slimpajama_convert.sh create mode 100644 smoe/models/mixtral/__init__.py create mode 100644 smoe/models/mixtral/configuration_mixtral.py create mode 100644 smoe/models/mixtral/modeling_mixtral.py create mode 100644 smoe/utils/cache_utils.py create mode 100644 smoe/utils/modeling_attn_mask_utils.py create mode 100644 smoe/utils/split_files.py create mode 100644 tools/check_killed.py create mode 100644 tools/listen.py create mode 100644 tools/queue_submit.py create mode 100644 tools/scl_jobs.sh diff --git a/.vscode/launch.json b/.vscode/launch.json index b2bde74..96f81b0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,13 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "tokenize", + "type": "python", + "request": "launch", + "module": "smoe.utils.tokenize", + "justMyCode": true + }, { "name": "Python: Remote Attach", "type": "python", @@ -21,4 +28,4 @@ "justMyCode": false } ] -} +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 98417b0..5129821 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,3 +38,4 @@ numpy==1.25.0 opencv-python==4.8.1.78 pynvml==11.5.0 PyYaml==6.0.1 +pandas<2.1.0 diff --git a/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh new file mode 100644 index 0000000..bed9b34 --- /dev/null +++ b/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh @@ -0,0 +1,149 @@ +#!/usr/bin/bash + +#SBATCH --job-name=mxitral_random_split_112gpus_8_2 +#SBATCH --output=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log + +#SBATCH --partition=MoE_T +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=2 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=2 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + model_type="mixtral" + comment="mistral 7B, random 2/8, sheared llama data portion" + pretrained_model=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2 + tokenizer_path=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2 + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=2 + scale_factor=4.0 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + # warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + warmup_steps=100 + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + # eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + eval_steps=340 + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --num_selects ${num_selects} \ + --moe_calculator_score_scale_factor ${scale_factor} \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps ${warmup_steps} \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" +} diff --git a/scripts/tokenize/slimpajama_convert.sh b/scripts/tokenize/slimpajama_convert.sh new file mode 100644 index 0000000..1e510db --- /dev/null +++ b/scripts/tokenize/slimpajama_convert.sh @@ -0,0 +1,52 @@ +#!/usr/bin/bash + +# set -vx + +content_column=input_ids +src_tokenizer_dir=/mnt/petrelfs/share_data/zhutong/models/llama2_7B +tokenizer_dir=/mnt/petrelfs/share_data/zhutong/models/Mistral-7B-v0.1 + +data_dir=/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_llama_middle_parts +out_dir=/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral_middle_parts +# data_dir=/mnt/petrelfs/share_data/zhutong/data/llama1_7B_val_set_tokenized +# out_dir=/mnt/petrelfs/share_data/zhutong/data/mixtral_val_set_tokenized + + +logs_dir=logs + +mkdir -p $logs_dir + +# for loop in: en_arxiv, en_book, en_c4, en_cc, en_stack, en_wikipedia, github +# for data_type in $(ls $data_dir) +for data_type in "en_arxiv" "en_book" "en_c4" "en_stack" "en_wikipedia" "github" +do + # get all parts from source data dir + for part in $(ls $data_dir/$data_type) + do + echo "tokenizing $data_dir/$data_type/$part - $(ls $data_dir/$data_type/$part | wc -l)" + log_path=logs/tokenize-$data_type-$part.log + nohup srun -p MoE_T -N1 -n1 --cpus-per-task=32 \ + python -m smoe.utils.tokenize \ + -f jsonl \ + -c $content_column \ + -s $src_tokenizer_dir \ + -t $tokenizer_dir \ + -i $data_dir/$data_type/$part \ + -o $out_dir/$data_type/$part \ + 1>$log_path 2>&1 & + # echo "$data_type/$part > $log_path" + sleep 3 + done + + # log_path=logs/tokenize_$data_type.log + # nohup srun -p MoE_T -N1 -n1 --cpus-per-task=32 \ + # python -m smoe.utils.tokenize \ + # -f jsonl \ + # -s $src_tokenizer_dir \ + # -c $content_column \ + # -t $tokenizer_dir \ + # -i $data_dir/$data_type \ + # -o $out_dir/$data_type \ + # 1>$logs_dir/tokenize_$data_type.log 2>&1 & + # echo "$data_type > $logs_dir/tokenize_$data_type.log" +done diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index 8dcfe80..da4a5e5 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -36,6 +36,8 @@ LlamaMoEResidualConfig, LlamaMoEResidualForCausalLM, ) +from smoe.models.mixtral.configuration_mixtral import MixtralConfig +from smoe.models.mixtral.modeling_mixtral import MixtralForCausalLM from smoe.modules.flash_attn import replace_xformers from smoe.trainer.llama_lr_scheduling import LlamaLrSchedulingTrainer from smoe.utils.config import ( @@ -51,6 +53,7 @@ "llama": LlamaForCausalLM, "llama_moe": LlamaMoEForCausalLM, "llama_moe_residual": LlamaMoEResidualForCausalLM, + "mixtral": MixtralForCausalLM, } CONFIG_MAPPING.update( @@ -58,6 +61,7 @@ "llama": LlamaConfig, "llama_moe": LlamaMoEConfig, "llama_moe_residual": LlamaMoEResidualConfig, + "mixtral": MixtralConfig, } ) @@ -276,6 +280,7 @@ def main(): # model.half() # model.to(torch_dtype) + # TODO (tzhu): add flash-attn for mixtral model: LlamaForCausalLM | LlamaMoEForCausalLM | LlamaMoEResidualForCausalLM = ( ModelClass.from_pretrained( model_args.model_name_or_path, diff --git a/smoe/models/mixtral/__init__.py b/smoe/models/mixtral/__init__.py new file mode 100644 index 0000000..4ad441d --- /dev/null +++ b/smoe/models/mixtral/__init__.py @@ -0,0 +1,62 @@ +# Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mixtral"] = [ + "MixtralForCausalLM", + "MixtralModel", + "MixtralPreTrainedModel", + "MixtralForSequenceClassification", + ] + + +if TYPE_CHECKING: + from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mixtral import ( + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralModel, + MixtralPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) \ No newline at end of file diff --git a/smoe/models/mixtral/configuration_mixtral.py b/smoe/models/mixtral/configuration_mixtral.py new file mode 100644 index 0000000..207f163 --- /dev/null +++ b/smoe/models/mixtral/configuration_mixtral.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mixtral model configuration""" + +import copy +from typing import Dict, Any + +from transformers import __version__ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json", +} + + +def recursive_diff_dict(dict_a, dict_b, config_obj=None): + """ + Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the + values from `dict_a` that are different from values in `dict_b`. + """ + diff = {} + default = config_obj.__class__().to_dict() if config_obj is not None else {} + for key, value in dict_a.items(): + obj_value = getattr(config_obj, str(key), None) + if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict): + diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value) + if len(diff_value) > 0: + diff[key] = diff_value + elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]: + diff[key] = value + return diff + + +class MixtralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an + Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. + + [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B) + [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MixtralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import MixtralModel, MixtralConfig + + >>> # Initializing a Mixtral 7B style configuration + >>> configuration = MixtralConfig() + + >>> # Initializing a model from the Mixtral 7B style configuration + >>> model = MixtralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mixtral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=4096, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + self.scale_factor = kwargs.pop("scale_factor", 1.0) + # Attention implementation to use, if relevant. + self._attn_implementation_internal = kwargs.pop("attn_implementation", None) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] + + # Transformers version when serializing the model + output["transformers_version"] = __version__ + + for key, value in output.items(): + # Deal with nested configs like CLIP + if isinstance(value, PretrainedConfig): + value = value.to_dict() + del value["transformers_version"] + + output[key] = value + + if hasattr(self, "quantization_config"): + output["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(output) + + return output + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if ( + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) + ): + # For nested configs we need to clean the diff recursively + diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) + if "model_type" in value: + # Needs to be set even if it's not in the diff + diff["model_type"] = value["model_type"] + if len(diff) > 0: + serializable_config_dict[key] = diff + elif ( + key not in default_config_dict + or key == "transformers_version" + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): + serializable_config_dict[key] = value + + if hasattr(self, "quantization_config"): + serializable_config_dict["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = serializable_config_dict.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(serializable_config_dict) + + if "_attn_implementation_internal" in serializable_config_dict: + del serializable_config_dict["_attn_implementation_internal"] + + return serializable_config_dict diff --git a/smoe/models/mixtral/modeling_mixtral.py b/smoe/models/mixtral/modeling_mixtral.py new file mode 100644 index 0000000..fa8c5f3 --- /dev/null +++ b/smoe/models/mixtral/modeling_mixtral.py @@ -0,0 +1,1643 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mixtral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available + +from smoe.utils.cache_utils import Cache, DynamicCache +from smoe.utils.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, +) + +from .configuration_mixtral import MixtralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2 +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts]. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None: + return 0 + + if isinstance(gate_logits, tuple): + # cat along the layers? + compute_device = gate_logits[0].device + gate_logits = torch.cat( + [gate.to(compute_device) for gate in gate_logits], dim=0 + ) + + routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) + return torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1) + ) * (num_experts**2) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], + dim=-1, + ) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class MixtralBLockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, +} + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList( + [MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + MixtralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a sequence classification head on top (linear layer). + + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/smoe/utils/cache_utils.py b/smoe/utils/cache_utils.py new file mode 100644 index 0000000..8220705 --- /dev/null +++ b/smoe/utils/cache_utils.py @@ -0,0 +1,326 @@ +""" +Borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py +""" + +from typing import Any, Dict, List, Optional, Tuple + +import torch + + +class Cache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def __init__(self) -> None: + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return None + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + """ + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_cache = {} + self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, cos[: self.window_length], sin[: self.window_length] + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) diff --git a/smoe/utils/modeling_attn_mask_utils.py b/smoe/utils/modeling_attn_mask_utils.py new file mode 100644 index 0000000..94cc79a --- /dev/null +++ b/smoe/utils/modeling_attn_mask_utils.py @@ -0,0 +1,472 @@ +# This is a script borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `attention_mask` is + ``` + [[0, 0, 1], + [1, 1, 1], + [0, 1, 1]] + ``` + and `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + + # Get the index of the first non-zero value for every sample in the batch. + # In the above example, indices = [[2], [0], [1]]] + tmp = torch.arange(attention_mask.shape[1], 0, -1) + indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) + + # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the + # expanded mask will be completely unattended. + left_masked_rows = torch.where(indices > 0)[0] + + if left_masked_rows.shape[0] == 0: + return expanded_mask + indices = indices[left_masked_rows] + + max_len = torch.max(indices) + range_tensor = torch.arange(max_len).unsqueeze(0) + range_tensor = range_tensor.repeat(indices.size(0), 1) + + # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. + range_tensor[range_tensor >= indices] = 0 + + # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case + if expanded_mask.dim() == 4: + num_masks = expanded_mask.shape[1] + if num_masks == 1: + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], 0, range_tensor) + else: + # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len] + mask_slice = ( + left_masked_rows[:, None, None], + torch.arange(num_masks)[None, :, None], + range_tensor[:, None, :], + ) + else: + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], range_tensor) + + expanded_mask[mask_slice] = unmasked_value + + return expanded_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + batch_size, query_length = input_shape + + # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() + + if attention_mask is not None: + if torch.all(attention_mask == 1): + if is_tracing: + pass + elif query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + attention_mask = None + elif key_value_length == query_length: + attention_mask = None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + pass + elif query_length > 1 and key_value_length != query_length: + # See the comment above (https://github.com/pytorch/pytorch/issues/108108). + # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. + attention_mask = True + elif is_tracing: + raise ValueError( + 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' + ) + + if attention_mask is None: + expanded_4d_mask = None + elif attention_mask is True: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if query_length > 1: + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, attention_mask, unmasked_value=0.0 + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + batch_size, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() + + if torch.all(mask == 1): + if is_tracing: + pass + elif tgt_len == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + return None + elif key_value_length == tgt_len: + return None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/smoe/utils/split_files.py b/smoe/utils/split_files.py new file mode 100644 index 0000000..8e2ec0c --- /dev/null +++ b/smoe/utils/split_files.py @@ -0,0 +1,56 @@ +""" +split files in a folder to separate folders + +src: en_arxiv/* +tgt: output/part0, output/part1, ... +""" + +from pathlib import Path + + +def split_files(src_dir, tgt_dir, num_parts): + src_dir = Path(src_dir) + tgt_dir = Path(tgt_dir) + tgt_dir.mkdir(parents=True, exist_ok=True) + + filepaths = sorted(src_dir.glob("*.jsonl")) + num_files = len(filepaths) + num_files_per_part = num_files // num_parts + print(f"{src_dir} --> {tgt_dir}") + print(f"num_files_per_part: {num_files_per_part}") + + for i in range(num_parts): + start = i * num_files_per_part + end = (i + 1) * num_files_per_part + if i == num_parts - 1: + end = num_files + print(f"part-{i}, start: {start}, end: {end}") + + part_dir = tgt_dir / f"part-{i:06d}" + part_dir.mkdir(parents=True, exist_ok=True) + for j in range(start, end): + filepath = filepaths[j] + tgt_filepath = part_dir / filepath.name + tgt_filepath.symlink_to(filepath) + + +if __name__ == "__main__": + for data_type in [ + # "en_arxiv", + # "en_book", + # "en_c4", + "en_cc", + # "en_stack", + # "en_wikipedia", + # "github", + ]: + split_files( + f"/mnt/hwfile/share_data/zhutong/slimpajama_fluency_llama/{data_type}", + f"/mnt/hwfile/share_data/zhutong/data/slimpajama_fluency_llama_middle_parts/{data_type}", + 30, + ) + # split_files( + # "/mnt/hwfile/share_data/zhutong/slimpajama_fluency_llama/en_arxiv", + # "/mnt/hwfile/share_data/zhutong/data/slimpajama_fluency_llama_middle_parts/en_arxiv", + # 30, + # ) diff --git a/smoe/utils/tokenize.py b/smoe/utils/tokenize.py index 60a6e48..446633f 100644 --- a/smoe/utils/tokenize.py +++ b/smoe/utils/tokenize.py @@ -13,7 +13,21 @@ def get_parser(): parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--src_tokenizer", + required=False, + default=None, + help="source tokenizer filepath", + ) parser.add_argument("-t", "--tokenizer", required=True, help="tokenizer filepath") + parser.add_argument( + "-c", + "--content_column", + required=False, + default="content", + help="content column name", + ) parser.add_argument( "-i", "--input", required=True, help="filepath or dir with jsonl to tokenize" ) @@ -37,21 +51,21 @@ def get_parser(): return args -def load_jsonlines(filepath): +def load_jsonlines(filepath, content_column: str = "content"): data = [] with open(filepath, "r", encoding="utf8") as fin: for line in tqdm(fin, desc="Loading"): ins = json.loads(line) - if "content" in ins: - data.append({"content": ins["content"]}) + if content_column in ins: + data.append({content_column: ins[content_column]}) return data -def load_txt(filepath): +def load_txt(filepath, content_column: str = "content"): data = [] with open(filepath, "r", encoding="utf8") as fin: for line in tqdm(fin, desc="Loading"): - data.append({"content": line.strip()}) + data.append({content_column: line.strip()}) return data @@ -85,6 +99,12 @@ def prepare_meta(jsonl_filepath: str): def tokenize_jsonl(): args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=args.use_fast) + if args.src_tokenizer is not None: + src_tokenizer = AutoTokenizer.from_pretrained( + args.src_tokenizer, use_fast=args.use_fast + ) + else: + src_tokenizer = None input_path = Path(args.input) output_path = Path(args.output) @@ -100,20 +120,27 @@ def tokenize_jsonl(): def _tokenize_and_dump(input_filepath, output_filepath): if args.format == "jsonl": - data = load_jsonlines(input_filepath) + data = load_jsonlines(input_filepath, content_column=args.content_column) elif args.format == "txt": - data = load_txt(input_filepath) + data = load_txt(input_filepath, content_column=args.content_column) else: raise ValueError(f"{args.format} format not supported") ds = Dataset.from_list(data) column_names = ds.column_names - text_column_name = "content" if "content" in column_names else column_names[0] + # text_column_name = "content" if "content" in column_names else column_names[0] + # text_column_name = args.content_column def _tokenization_func(examples): - return {"input_ids": tokenizer(examples[text_column_name])["input_ids"]} - - ds = ds.filter(lambda example: text_column_name in example) + contents = examples[args.content_column] + if src_tokenizer is not None: + # decode input_ids to text + contents = src_tokenizer.batch_decode( + contents, skip_special_tokens=True + ) + return {"input_ids": tokenizer(contents)["input_ids"]} + + ds = ds.filter(lambda example: args.content_column in example) tokenized_ds = ds.map( _tokenization_func, batched=True, @@ -163,6 +190,13 @@ def update_meta_without_tokenization(data_dir: str): if __name__ == "__main__": + # import sys + + # sys.argv = ( + # sys.argv + # + "-s /mnt/petrelfs/share_data/zhutong/models/llama2_7B -t /mnt/petrelfs/share_data/zhutong/models/llama2_7B -i /mnt/petrelfs/share_data/zhutong/slimpajama_fluency_llama/en_arxiv/part-000000-79b0b564.jsonl -o arxiv.jsonl -f jsonl -p 1 --content_column input_ids".split() + # ) + tokenize_jsonl() # # uncomment and run: srun -p MoE -c 16 python -m smoe.utils.tokenize diff --git a/tools/check_killed.py b/tools/check_killed.py new file mode 100644 index 0000000..2861b2c --- /dev/null +++ b/tools/check_killed.py @@ -0,0 +1,73 @@ +import re +import subprocess +from pathlib import Path +from collections import defaultdict, Counter + + +def get_jobstate(job_id): + cmd = f"sacct -j {job_id} -o state -n" + p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + ret = p.stdout.read().decode("utf8").strip() + return ret + + +def get_data_type_and_part_id(filepath): + path = Path(filepath) + obj = re.search(r"tokenize-(.*?)-part-(\d+).log", path.name) + if obj is None: + return None + data_type, part_id = obj.groups() + return data_type, part_id + + +def check_result(filepath): + path = Path(filepath) + ret = get_data_type_and_part_id(filepath) + if ret is None: + return None + data_type, part_id = ret + content = path.read_text(encoding="utf8") + + if ( + "srun: error: Unable to allocate resources: Reach max user active rpc limit" + in content + or "srun: error: Unable to allocate resources: Socket timed out on send/recv operation" + in content + ): + print(f"Error: {data_type}/{part_id}") + return "error" + + obj = re.search(r"srun: job (\d+) queued and waiting for resources", content) + if obj is None: + print(f"Unknown: {data_type}/{part_id}") + return "unknown" + + job_id = obj.group(1) + jobstate = get_jobstate(job_id) + obj = re.search(r"Tokenization Progress:\s*100%\s*\|.*\|\s*(\d+)/(\d+)", content) + if obj is not None: + progress, total = obj.groups() + if progress == total and progress is not None and total is not None and jobstate != "COMPLETED": + print(f"DEAD_COMPLETED: {data_type}/{part_id} - job: {job_id}") + return "DEAD_COMPLETED" + + print(f"{jobstate}: {data_type}/{part_id}") + return jobstate + + +if __name__ == "__main__": + status = defaultdict(list) + for filepath in Path("logs").glob("tokenize-*.log"): + s = check_result(filepath) + res = get_data_type_and_part_id(filepath) + status[s].append(res) + + print(Counter({k: len(v) for k, v in status.items()}).most_common()) + + def print_val(v, k): + print(f"# {k} = {len(v[k])}") + for path in v[k]: + print(path) + + for key in ["CANCELLED+", "DEAD_COMPLETED", "error", None]: + print_val(status, key) diff --git a/tools/listen.py b/tools/listen.py new file mode 100644 index 0000000..a4260ac --- /dev/null +++ b/tools/listen.py @@ -0,0 +1,45 @@ +import time +import subprocess + +from smoe.utils.notification import send_to_wechat + + +def check_sme_pending(): + # run sme | grep "normal PD" | wc -l, if the returned value is 0, then send a notification + cmd = "squeue --me | grep 'normal PD' | wc -l" + p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + for line in p.stdout.readlines(): + line = line.decode("utf-8") + if int(line) == 0: + send_to_wechat("pending jobs all clear!!!") + return True + return False + + +def check_sme_running(): + # run sme | grep "normal R" | wc -l, if the returned value is 0, then send a notification + cmd = "squeue --me | grep 'normal R' | wc -l" + p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + for line in p.stdout.readlines(): + line = line.decode("utf-8") + if int(line) == 0: + send_to_wechat("running jobs all clear!!!") + return True + return False + + +def listen(): + # check pending jobs every 10 seconds, if all pending jobs are done, send a notification + no_pending = False + no_running = False + while True: + if not no_pending: + no_pending = check_sme_pending() + time.sleep(10) + if not no_running: + no_running = check_sme_running() + time.sleep(10) + + +if __name__ == "__main__": + listen() diff --git a/tools/queue_submit.py b/tools/queue_submit.py new file mode 100644 index 0000000..8f8c2b0 --- /dev/null +++ b/tools/queue_submit.py @@ -0,0 +1,144 @@ +import re +import time +import subprocess +from pathlib import Path + +from loguru import logger + +from smoe.utils.notification import send_to_wechat + +from check_killed import get_jobstate + + +logger.add("logs/queue_submit.log") + + +def run_command(command): + try: + logger.info(f"Running cmd: {command}") + subprocess.run( + command, + shell=True, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except subprocess.CalledProcessError as e: + logger.info(f"An error occurred: {e}") + + +def get_jobid(filepath): + content = Path(filepath).read_text(encoding="utf8") + obj = re.search(r"srun: job (\d+) queued and waiting for resources", content) + if obj is None: + return None + job_id = obj.group(1) + return job_id + + +if __name__ == "__main__": + task_list = [ + # CANCELLED+ + ('en_stack', '000028'), + ('en_wikipedia', '000004'), + ('en_wikipedia', '000006'), + ('en_wikipedia', '000009'), + ('github', '000024'), + ('github', '000029'), + ('en_wikipedia', '000010'), + ('en_wikipedia', '000012'), + ('en_wikipedia', '000024'), + ('en_wikipedia', '000026'), + ('en_wikipedia', '000029'), + ('github', '000004'), + ('en_wikipedia', '000000'), + ('en_wikipedia', '000002'), + ('github', '000014'), + ('github', '000016'), + ('github', '000019'), + ('github', '000020'), + ('github', '000022'), + # error = 10, + ('github', '000011'), + ('github', '000013'), + ('github', '000027'), + ('github', '000007'), + ('github', '000008'), + ('github', '000010'), + ('github', '000012'), + ('github', '000026'), + ('github', '000006'), + ('github', '000009'), + # un-processed, + ("en_cc", "000000"), + ("en_cc", "000001"), + ("en_cc", "000002"), + ("en_cc", "000003"), + ("en_cc", "000004"), + ("en_cc", "000005"), + ("en_cc", "000006"), + ("en_cc", "000007"), + ("en_cc", "000008"), + ("en_cc", "000009"), + ("en_cc", "000010"), + ("en_cc", "000011"), + ("en_cc", "000012"), + ("en_cc", "000013"), + ("en_cc", "000014"), + ("en_cc", "000015"), + ("en_cc", "000016"), + ("en_cc", "000017"), + ("en_cc", "000018"), + ("en_cc", "000019"), + ("en_cc", "000020"), + ("en_cc", "000021"), + ("en_cc", "000022"), + ("en_cc", "000023"), + ("en_cc", "000024"), + ("en_cc", "000025"), + ("en_cc", "000026"), + ("en_cc", "000027"), + ("en_cc", "000028"), + ("en_cc", "000029"), + ] + data_dir = Path( + "/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_llama_middle_parts" + ) + out_dir = Path( + "/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral_middle_parts" + ) + + def submit_job(data_type, part_id): + run_command( + f"nohup srun -p MoE_T -N1 -n1 -c 32 " + f"python -m smoe.utils.tokenize " + f"-f jsonl -c input_ids " + "-s /mnt/petrelfs/share_data/zhutong/models/llama2_7B " + "-t /mnt/petrelfs/share_data/zhutong/models/Mistral-7B-v0.1 " + f"-i {str(data_dir / data_type / f'part-{part_id}')} " + f"-o {str(out_dir / data_type / f'part-{part_id}')} " + f"1>logs/tokenize-{data_type}-part-{part_id}.log 2>&1 &" + ) + + wait_seconds = 5 + check_times = 10 + while len(task_list) > 0: + data_type, part_id = task_list.pop(0) + filepath = data_dir / data_type / f"part-{part_id}" + logger.info(f"Processing: {filepath}") + submit_job(data_type, part_id) + check_times = 10 + time.sleep(5) + jobstate = "PENDING" + while jobstate != "RUNNING" and jobstate != "COMPLETED": + if "CANCELLED" in jobstate: + send_to_wechat(f"Job {data_type}-{part_id} is cancelled, resubmit") + submit_job(data_type, part_id) + if check_times <= 0: + wait_seconds = 600 + time.sleep(wait_seconds) + job_id = get_jobid(f"logs/tokenize-{data_type}-part-{part_id}.log") + jobstate = get_jobstate(job_id) + logger.info(f"Check job: {job_id} - {jobstate}") + check_times -= 1 diff --git a/tools/scl_jobs.sh b/tools/scl_jobs.sh new file mode 100644 index 0000000..9c705e9 --- /dev/null +++ b/tools/scl_jobs.sh @@ -0,0 +1,45 @@ +# scancel from the list below + +list=( + "2384204" + "2384206" + "2384207" + "2384208" + "2384209" + "2384210" + "2384211" + "2384213" + "2384215" + "2384216" + "2384217" + "2384218" + "2384220" + "2384221" + "2384222" + "2384223" + "2384226" + "2384228" + "2384230" + "2384231" + "2384233" + "2384234" + "2384264" + "2384262" + "2384261" + "2384259" + "2384257" + "2384255" + "2384253" + "2384251" + "2384249" + "2384244" + "2384242" + "2384240" + "2384238" + "2384236" +) + +for i in "${list[@]}" +do + scancel $i +done From 0209e57ea3f1b8695a6e182d1ebcb137c1f36009 Mon Sep 17 00:00:00 2001 From: Tong Zhu Date: Sun, 17 Dec 2023 16:45:18 +0800 Subject: [PATCH 07/12] add mixtral support --- ...s_8_2_sheared_llama_portion_fluency_sf4.sh | 24 +- scripts/cpt/test/test_conn.sh | 34 +++ smoe/entrypoint/cpt/cpt_fpt.py | 14 +- smoe/models/mixtral/configuration_mixtral.py | 2 +- smoe/models/mixtral/modeling_mixtral.py | 223 +++++++++++++++++- smoe/utils/notification.py | 84 +++---- tools/cp_files.py | 32 +++ 7 files changed, 342 insertions(+), 71 deletions(-) create mode 100644 scripts/cpt/test/test_conn.sh create mode 100644 tools/cp_files.py diff --git a/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh index bed9b34..12a670b 100644 --- a/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh +++ b/scripts/cpt/8_2/mixtral_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh @@ -1,15 +1,15 @@ #!/usr/bin/bash -#SBATCH --job-name=mxitral_random_split_112gpus_8_2 -#SBATCH --output=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log -#SBATCH --error=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log +#SBATCH --job-name=mixtral_random_split_112gpus_8_2 +#SBATCH --output=/mnt/petrelfs/share_data/zhutong/runs/mixtral_random_split_112gpus_8_2/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/zhutong/runs/mixtral_random_split_112gpus_8_2/%x-%j.log #SBATCH --partition=MoE_T #SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=64 +#SBATCH --cpus-per-task=26 #SBATCH --mem=0 -#SBATCH --nodes=2 +#SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved @@ -18,8 +18,8 @@ source ~/anaconda3/bin/activate smoe { - num_nodes=2 # should match with --nodes - num_gpu_per_node=8 # should match with --gres + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres # #cpu/#num_gpu_per_node export OMP_NUM_THREADS=32 @@ -33,8 +33,8 @@ source ~/anaconda3/bin/activate smoe comment="mistral 7B, random 2/8, sheared llama data portion" pretrained_model=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2 tokenizer_path=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2 - dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg - validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + dataset_dir=/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral + validation_dir=/mnt/petrelfs/share_data/zhutong/data/mixtral_val_set_tokenized lr=2e-4 final_lr_portion=0.1 @@ -68,7 +68,7 @@ source ~/anaconda3/bin/activate smoe echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" data_cache=resources/cache - base_dir="/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2" + base_dir="/mnt/petrelfs/share_data/zhutong/runs/mixtral_random_split_112gpus_8_2" output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID mkdir -p $output_dir echo "output_dir: $output_dir" @@ -86,7 +86,7 @@ source ~/anaconda3/bin/activate smoe head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) echo "Node: $head_node" echo "Node IP: $head_node_ip" - echo "Node list: $SLURM_JOB_NODELIS" + echo "Node list: $nodes" srun torchrun \ --nnodes ${num_nodes} \ @@ -94,7 +94,7 @@ source ~/anaconda3/bin/activate smoe --node_rank $SLURM_NODEID \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ - --rdzv_endpoint $head_node:29518 \ + --rdzv_endpoint $head_node:29519 \ smoe/entrypoint/cpt/cpt_fpt.py \ --prob_map "sheared_llama" \ --num_selects ${num_selects} \ diff --git a/scripts/cpt/test/test_conn.sh b/scripts/cpt/test/test_conn.sh new file mode 100644 index 0000000..f946bea --- /dev/null +++ b/scripts/cpt/test/test_conn.sh @@ -0,0 +1,34 @@ +# !/usr/bin/bash + +# SBATCH --job-name=test_conn +# SBATCH --output=logs/test_conn.log +# SBATCH --error=logs/test_conn.log + +# SBATCH --partition=MoE_T +# SBATCH --ntasks-per-node=1 +# SBATCH --cpus-per-task=26 +# SBATCH --mem=0 + +# SBATCH --nodes=8 +# SBATCH --gres=gpu:1 +# SBATCH --quotatype=reserved + +# srun -p MoE_T -N8 -n8 --gres=gpu:1 -w HOST-10-140-60-[134,141,163,180-181,184] torchrun --nnodes 8 --nproc_per_node 1 tests/entrypoint/test_conn.py +# $ srun -p MoE_T -N8 -n8 --gres=gpu:1 -w HOST-10-140-60-[134,141,163,180-181,184] bash scripts/cpt/test/test_conn.sh + +nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +echo "Node: $head_node" +echo "Node IP: $head_node_ip" +echo "Node list: $nodes" + +torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29519 \ + tests/entrypoint/test_conn.py diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index da4a5e5..41c772c 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -69,7 +69,7 @@ logger = logging.getLogger(__name__) -@wechat_sender(msg_prefix="CPT Training") +# @wechat_sender(msg_prefix="CPT Training") def main(): model_args, data_args, training_args = parse_args( ModelArguments, DataArguments, EnhancedTrainingArguments @@ -157,6 +157,8 @@ def main(): or model_args.model_type == "llama_moe_residual" ): ConfigClass = LlamaMoEResidualConfig + elif model_args.config_name == "mixtral" or model_args.model_type == "mixtral": + ConfigClass = MixtralConfig if model_args.config_name: config = ConfigClass.from_pretrained(model_args.config_name, **config_kwargs) @@ -179,6 +181,10 @@ def main(): if training_args.debug_mode: config.num_hidden_layers = 2 + if model_args.model_type == "mixtral" or model_args.model_name_or_path == "mixtral": + config.num_experts_per_tok = model_args.num_selects + config.output_router_logits = True + tokenizer_kwargs = { "cache_dir": model_args.cache_dir, "use_fast": model_args.use_fast_tokenizer, @@ -280,8 +286,10 @@ def main(): # model.half() # model.to(torch_dtype) - # TODO (tzhu): add flash-attn for mixtral - model: LlamaForCausalLM | LlamaMoEForCausalLM | LlamaMoEResidualForCausalLM = ( + if isinstance(config, MixtralConfig): + config._attn_implementation = "flash_attention_2" + + model: LlamaForCausalLM | LlamaMoEForCausalLM | LlamaMoEResidualForCausalLM | MixtralForCausalLM = ( ModelClass.from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/smoe/models/mixtral/configuration_mixtral.py b/smoe/models/mixtral/configuration_mixtral.py index 207f163..38fd675 100644 --- a/smoe/models/mixtral/configuration_mixtral.py +++ b/smoe/models/mixtral/configuration_mixtral.py @@ -183,7 +183,7 @@ def __init__( self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef - self.scale_factor = kwargs.pop("scale_factor", 1.0) + self.score_scale_factor = kwargs.pop("score_scale_factor", 4.0) # Attention implementation to use, if relevant. self._attn_implementation_internal = kwargs.pop("attn_implementation", None) diff --git a/smoe/models/mixtral/modeling_mixtral.py b/smoe/models/mixtral/modeling_mixtral.py index fa8c5f3..da21a9b 100644 --- a/smoe/models/mixtral/modeling_mixtral.py +++ b/smoe/models/mixtral/modeling_mixtral.py @@ -21,7 +21,10 @@ import inspect import math import warnings -from typing import List, Optional, Tuple, Union +import importlib +from dataclasses import dataclass +from packaging import version +from typing import List, Optional, Tuple, Union, Callable import torch import torch.nn.functional as F @@ -30,19 +33,16 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, + ModelOutput, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + is_torch_available, ) from transformers.utils.import_utils import is_torch_fx_available @@ -53,6 +53,183 @@ from .configuration_mixtral import MixtralConfig +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) +is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") + + +def _is_package_available( + pkg_name: str, return_version: bool = False +) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + logger.debug(f"Detected {pkg_name} version {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +def is_flash_attn_2_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn"): + return False + + # Let's add an extra check to see if cuda is available + import torch + + if not torch.cuda.is_available(): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.0.4" + ) + else: + return False + + +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + + +def is_flash_attn_available(): + logger.warning( + "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. " + "Please use `is_flash_attn_2_available` instead." + ) + return is_flash_attn_2_available() + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + @property + def balance_loss(self): + return self.aux_loss + + @property + def num_dropped_tokens(self): + return [torch.tensor(-1)] * 32 + + @property + def gate_load(self): + return [torch.tensor(-1)] * 32 + + @property + def gate_importance(self): + return [torch.tensor(-1)] * 32 + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -71,11 +248,6 @@ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "MixtralConfig" - - def load_balancing_loss_func( gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2 ) -> float: @@ -789,6 +961,7 @@ def __init__(self, config): self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok + self.score_scale_factor = config.score_scale_factor # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) @@ -843,6 +1016,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = ( expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + * self.score_scale_factor ) # However `index_add_` only support torch tensors for indexing so we'll use @@ -1211,8 +1385,16 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs: tuple = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids, @@ -1221,6 +1403,17 @@ def forward( output_router_logits, use_cache, ) + + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # output_router_logits, + # use_cache, + # ) else: layer_outputs = decoder_layer( hidden_states, @@ -1310,6 +1503,10 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MixtralModel): + module.gradient_checkpointing = value + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC diff --git a/smoe/utils/notification.py b/smoe/utils/notification.py index 050e8d7..4517c23 100644 --- a/smoe/utils/notification.py +++ b/smoe/utils/notification.py @@ -127,31 +127,57 @@ def wrapper_sender(*args, **kwargs): if webhook_url: requests.post(webhook_url, json=msg_template) - try: - value = func(*args, **kwargs) + try: + value = func(*args, **kwargs) + + if master_process: + end_time = datetime.datetime.now() + elapsed_time = end_time - start_time + contents = [ + "Your training is complete 🎉", + "Machine name: %s" % host_name, + "Main call: %s" % func_name, + f"Job {get_slurm_job_name()}", + "Starting date: %s" % start_time.strftime(DATE_FORMAT), + "End date: %s" % end_time.strftime(DATE_FORMAT), + "Training duration: %s" % str(elapsed_time), + ] + + try: + str_value = str(value) + contents.append("Main call returned value: %s" % str_value) + except Exception: + contents.append( + "Main call returned value: %s" + % "ERROR - Couldn't str the returned value." + ) + + msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( + contents + ) + logger.info(f"{json.dumps(msg_template, ensure_ascii=False)}") + if webhook_url: + requests.post(webhook_url, json=msg_template) + + return value - if master_process: + except Exception as ex: end_time = datetime.datetime.now() elapsed_time = end_time - start_time contents = [ - "Your training is complete 🎉", + "Your training has crashed ☠️", "Machine name: %s" % host_name, "Main call: %s" % func_name, f"Job {get_slurm_job_name()}", "Starting date: %s" % start_time.strftime(DATE_FORMAT), - "End date: %s" % end_time.strftime(DATE_FORMAT), - "Training duration: %s" % str(elapsed_time), + "Crash date: %s" % end_time.strftime(DATE_FORMAT), + "Crashed training duration: %s\n\n" % str(elapsed_time), + "Here's the error:", + "%s\n\n" % ex, + "Traceback:", + "%s" % traceback.format_exc(), ] - try: - str_value = str(value) - contents.append("Main call returned value: %s" % str_value) - except Exception: - contents.append( - "Main call returned value: %s" - % "ERROR - Couldn't str the returned value." - ) - msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( contents ) @@ -159,33 +185,7 @@ def wrapper_sender(*args, **kwargs): if webhook_url: requests.post(webhook_url, json=msg_template) - return value - - except Exception as ex: - end_time = datetime.datetime.now() - elapsed_time = end_time - start_time - contents = [ - "Your training has crashed ☠️", - "Machine name: %s" % host_name, - "Main call: %s" % func_name, - f"Job {get_slurm_job_name()}", - "Starting date: %s" % start_time.strftime(DATE_FORMAT), - "Crash date: %s" % end_time.strftime(DATE_FORMAT), - "Crashed training duration: %s\n\n" % str(elapsed_time), - "Here's the error:", - "%s\n\n" % ex, - "Traceback:", - "%s" % traceback.format_exc(), - ] - - msg_template["text"]["content"] = f"{msg_prefix}\n" + "\n".join( - contents - ) - logger.info(f"{json.dumps(msg_template, ensure_ascii=False)}") - if webhook_url: - requests.post(webhook_url, json=msg_template) - - raise ex + raise ex return wrapper_sender diff --git a/tools/cp_files.py b/tools/cp_files.py new file mode 100644 index 0000000..099299e --- /dev/null +++ b/tools/cp_files.py @@ -0,0 +1,32 @@ +import os +import shutil +from pathlib import Path + +from tqdm import tqdm + + +def copy_files(src_folder: str, dest_folder: str): + src_folder = Path(src_folder) + dest_folder = Path(dest_folder) + dest_folder.mkdir(parents=True, exist_ok=True) + files = src_folder.glob("**/*.jsonl") + for file in tqdm(files): + dest_file = dest_folder / file.name + if not dest_file.exists(): + # print(str(file), str(dest_file)) + # shutil.copy2(str(file), str(dest_file)) + # link the file to dest_folder + # os.link(str(file), str(dest_file)) + os.symlink(str(file), str(dest_file)) + + +if __name__ == "__main__": + # copy_files( + # "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed/c4_split_fluency/", + # "/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg/en_c4/" + # ) + for domain in ["en_book", "en_c4", "en_cc", "en_arxiv", "en_wikipedia", "en_stack", "github"]: + copy_files( + f"/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral_middle_parts/{domain}", + f"/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral/{domain}", + ) From d7d098944c1f9943591b2d7ac9783e29ebd837e0 Mon Sep 17 00:00:00 2001 From: zhutong Date: Sat, 23 Dec 2023 21:42:05 +0800 Subject: [PATCH 08/12] update readme, add mistral-moe support (buggy), fix `use_cache` when inference --- README.md | 52 +- .../llama_moe/configuration_llama_moe.py | 13 +- smoe/models/llama_moe/modeling_llama_moe.py | 37 +- .../models/llama_moe/modeling_llama_moe_hf.py | 1815 +++++++++++++++++ smoe/models/mistral/__init__.py | 66 + smoe/models/mistral/configuration_mistral.py | 303 +++ smoe/models/mistral/modeling_mistral.py | 1488 ++++++++++++++ smoe/utils/param_estimation.py | 2 + 8 files changed, 3710 insertions(+), 66 deletions(-) create mode 100644 smoe/models/llama_moe/modeling_llama_moe_hf.py create mode 100644 smoe/models/mistral/__init__.py create mode 100644 smoe/models/mistral/configuration_mistral.py create mode 100644 smoe/models/mistral/modeling_mistral.py diff --git a/README.md b/README.md index 4445f67..c0dc348 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ LLaMA-MoE favicon
📢 A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!! @@ -16,40 +16,36 @@ We build LLaMA-MoE with the following two steps: 2. Continually pre-train the initialized MoE model with an optimized data sampling weights from [Sheared LLaMA](https://arxiv.org/abs/2310.06694) and filtered datasets from [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama). -| Model | \#Activated Experts | \#Experts | \#Activated Params | \#Total Prams | Links | -| :----------------- | :-----------------: | :-------: | :----------------: | :-----------: | :----------------------------------------------------------------------------------------------: | -| OPT-2.7B | - | - | 2.7B | 2.7B | ([Zhang et al., 2022](https://huggingface.co/facebook/opt-2.7b)) | -| Pythia-2.8B | - | - | 2.8B | 2.8B | ([Biderman et al., 2023](https://huggingface.co/EleutherAI/pythia-2.8b)) | -| INCITE-BASE-3B | - | - | 2.8B | 2.8B | ([Together Computer, 2023](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1)) | -| Open-LLaMA-3B-v2 | - | - | 3.4B | 3.4B | ([Geng et al., 2023](https://huggingface.co/openlm-research/open_llama_3b_v2)) | -| Sheared-LLaMA-2.7B | - | - | 2.7B | 2.7B | ([Xia et al., 2023](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B)) | -| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | 6.7B | [[HF Weights]](https://huggingface.co) | -| **LLaMA-MoE-3.5B** | 4 | 16 | 3.5B | 6.7B | [[HF Weights]](https://huggingface.co) | +| Model | \#Activated Experts | \#Experts | \#Activated Params | Links | +| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: | +| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | +| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | +| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | -| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | -| :----------------- | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :-------: | -| OPT-2.7B | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 55.9 | 10.7 | 25.8 | 49.6 | -| Pythia-2.8B | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 54.4 | 8.6 | 26.8 | 50.6 | -| INCITE-BASE-3B | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 55.6 | 15.2 | 27.2 | 52.8 | -| Open-LLaMA-3B-v2 | **88.0** | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 59.5 | 16.0 | 26.8 | 54.9 | -| Sheared-LLaMA-2.7B | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 59.7 | 17.7 | **27.3** | 55.6 | -| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 59.7 | 17.0 | 26.8 | 54.8 | -| **LLaMA-MoE-3.5B** | 87.6 | **77.9** | **65.5** | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **63.2** | **20.3** | 26.8 | **57.2 ** | +| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | +| :------------------------------------------------------------------------------------ | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :-----: | +| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | 50.3 | +| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | 51.5 | +| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | 53.7 | +| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | 55.6 | +| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | 56.4 | +| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | 55.5 | +| **LLaMA-MoE-3.5B (4/16)** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | 57.7 | +| **LLaMA-MoE-3.5B (2/8)** | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 | 57.6 |

🚀 QuickStart

```python import torch -from transformers import AutoTokenizer -from smoe.models.llama_moe import LlamaMoEForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM - -model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2342244/checkpoint-13600/" -tokenizer = AutoTokenizer.from_pretrained(model_dir) -model = LlamaMoEForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) +model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8" +tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True) +model.eval() model.to("cuda:0") input_text = "Suzhou is famous of" @@ -61,7 +57,7 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) # Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three ``` -

🚧 Expert Initialization

+

🚧 Expert Construction

- Neuron-Independent - IndependentRandom: `bash ./scripts/moefication/split/run_split_random.sh` @@ -70,7 +66,7 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) - SharingInner: `bash ./scripts/moefication/split/run_split_gradient.sh` - SharingInter: `bash ./scripts/moefication/split/run_split_gradient_residual.sh` -For more information, please refer to [Expert Initialization docs](docs/moefication/README.md). +For more information, please refer to [Expert Construction docs](docs/moefication/README.md).

🚅 Continual Pre-training

diff --git a/smoe/models/llama_moe/configuration_llama_moe.py b/smoe/models/llama_moe/configuration_llama_moe.py index 5c25e5f..f659a27 100644 --- a/smoe/models/llama_moe/configuration_llama_moe.py +++ b/smoe/models/llama_moe/configuration_llama_moe.py @@ -1,11 +1,4 @@ -""" LLaMA model configuration""" - from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} class LlamaMoEConfig(PretrainedConfig): @@ -31,11 +24,11 @@ def __init__( pretraining_tp=1, tie_word_embeddings=False, rope_scaling=None, - #### -------- moe expert configs -------- #### + # -------- moe expert configs -------- num_experts=16, num_selects=4, size_experts=None, - #### -------- moe gate configs -------- #### + # -------- moe gate configs -------- gate_type="TopKBalancedNoisyGate", gate_network="mlp", gate_use_softmax=True, @@ -44,7 +37,7 @@ def __init__( gate_add_noise=True, # TopKBalancedNoisyGate gate_noise_epsilon=1e-2, - #### -------- moe calculator configs -------- #### + # -------- moe calculator configs -------- calculator_type="UniversalCalculator", multiply_gate_scores=True, score_scale_factor=1.0, diff --git a/smoe/models/llama_moe/modeling_llama_moe.py b/smoe/models/llama_moe/modeling_llama_moe.py index 440fe26..4666efa 100644 --- a/smoe/models/llama_moe/modeling_llama_moe.py +++ b/smoe/models/llama_moe/modeling_llama_moe.py @@ -28,18 +28,6 @@ _CONFIG_FOR_DOC = "LlamaMoEConfig" -@dataclass -class MoEDecoderLayerOutput(ModelOutput): - # zhutong: do not change the order of these fields!! - hidden_states: Optional[torch.FloatTensor] = None - balance_loss: Optional[float] = None - num_dropped_tokens: Optional[torch.Tensor] = None - gate_load: Optional[list[torch.Tensor]] = None - gate_importance: Optional[list[torch.Tensor]] = None - self_attn_weights: Optional[torch.FloatTensor] = None - present_key_value: Optional[torch.FloatTensor] = None - - @dataclass class BaseMoEModelOutputWithPast(ModelOutput): """ @@ -121,7 +109,7 @@ def forward( past_key_value=None, output_attentions=False, use_cache=False, - ) -> MoEDecoderLayerOutput: + ) -> tuple: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -154,12 +142,6 @@ def forward( if use_cache: outputs += (present_key_value,) - for i, _o in enumerate(outputs): - if not isinstance(_o, torch.Tensor): - raise RuntimeError( - f"outputs[{i}]({type(_o)}) should be torch.Tensor to support grad ckpt" - ) - return outputs def set_moe_num_selects(self, num_selects): @@ -357,21 +339,20 @@ def custom_forward(*inputs): output_attentions=output_attentions, use_cache=use_cache, ) - layer_outputs = MoEDecoderLayerOutput(*layer_outputs) - hidden_states = layer_outputs.hidden_states - if layer_outputs.balance_loss is not None: - balance_loss += layer_outputs.balance_loss + hidden_states = layer_outputs[0] + if layer_outputs[1] is not None: + balance_loss += layer_outputs[1] if use_cache: - next_decoder_cache += (layer_outputs.present_key_value,) + next_decoder_cache += (layer_outputs[6 if output_attentions else 5],) if output_attentions: - all_self_attns += (layer_outputs.self_attn_weights,) + all_self_attns += (layer_outputs[5],) - num_dropped_tokens += (layer_outputs.num_dropped_tokens,) - gate_load += (layer_outputs.gate_load,) - gate_importance += (layer_outputs.gate_importance,) + num_dropped_tokens += (layer_outputs[2],) + gate_load += (layer_outputs[3],) + gate_importance += (layer_outputs[4],) hidden_states = self.norm(hidden_states) diff --git a/smoe/models/llama_moe/modeling_llama_moe_hf.py b/smoe/models/llama_moe/modeling_llama_moe_hf.py new file mode 100644 index 0000000..da79503 --- /dev/null +++ b/smoe/models/llama_moe/modeling_llama_moe_hf.py @@ -0,0 +1,1815 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.distributions.normal import Normal +from transformers.activations import ACT2FN +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from .configuration_llama_moe import LlamaMoEConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaMoEConfig" + + +@dataclass +class CalculatorOutput(ModelOutput): + hidden_states: Optional[torch.FloatTensor] = None + num_dropped_tokens: Optional[int] = None + + +@dataclass +class BaseMoEModelOutputWithPast(ModelOutput): + """ + Args: + num_dropped_tokens: layer idx to the number of dropped tokens + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + balance_loss: Optional[float] = None + num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None + gate_load: Optional[Tuple[list]] = None + gate_importance: Optional[Tuple[list]] = None + + +@dataclass +class MoECausalLMOutputWithPast(CausalLMOutputWithPast): + balance_loss: Optional[float] = None + num_dropped_tokens: Optional[Tuple[int]] = None + gate_load: Optional[Tuple[list[torch.Tensor]]] = None + gate_importance: Optional[Tuple[list[torch.Tensor]]] = None + + +@dataclass +class MoEMlpOutput(ModelOutput): + hidden_states: Optional[torch.FloatTensor] = None + balance_loss: Optional[torch.FloatTensor] = None + num_dropped_tokens: Optional[int] = None + gate_load: Optional[list] = None + gate_importance: Optional[list] = None + + +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], + dim=-1, + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaMoEConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.pretraining_tp, dim=2 + ) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class TopKBalancedNoisyGate(nn.Module): + def __init__( + self, + input_size, + num_experts, + num_selects, + gate_network="mlp", + use_softmax=True, + use_balance=True, + balance_loss_weight=1e-2, + add_noise=True, + noise_epsilon=1e-2, + ): + super(TopKBalancedNoisyGate, self).__init__() + assert num_selects <= num_experts + self.input_size = input_size + self.num_experts = num_experts + self.num_selects = num_selects + + self.gate_network_type = gate_network + self.gate_network = self.get_gate_network(gate_network, input_size, num_experts) + + self.use_softmax = use_softmax + self.softmax = nn.Softmax(1) + + self.use_balance = use_balance + self.balance_loss_weight = balance_loss_weight + + # add_noise + self.add_noise = add_noise + self.noise_epsilon = noise_epsilon + self.warned = False + if self.add_noise: + self.weight_noise = nn.Linear(input_size, num_experts, bias=False) + self.weight_noise.weight.data = torch.zeros( + (num_experts, input_size), + requires_grad=True, + device=self.weight_noise.weight.data.device, + dtype=self.weight_noise.weight.data.dtype, + ) + self.mean = 0.0 + self.std = 1.0 + self.normal = Normal(self.mean, self.std) + self.softplus = nn.Softplus() + + self.reset_parameters() + + def get_gate_network(self, gate_type, input_size, num_experts): + gate_type = gate_type.lower() + + if gate_type == "linear": + gate_network = nn.Linear(input_size, num_experts, bias=False) + nn.init.zeros_(gate_network.weight) + elif gate_type == "mlp": + gate_network = torch.nn.Sequential( + torch.nn.Linear(input_size, num_experts, bias=False), + torch.nn.Tanh(), + torch.nn.Linear(num_experts, num_experts, bias=False), + ) + else: + raise ValueError(f"Unexpected gate_type: {gate_type}.") + + return gate_network + + def reset_gate_network(self): + if "gate_network_type" not in vars(self): + raise KeyError(f"{type(self)} does not have a gate network.") + else: + self.gate_network = self.get_gate_network( + self.gate_network_type, self.input_size, self.num_experts + ) + + def reset_parameters(self): + if self.add_noise: + nn.init.zeros_(self.weight_noise.weight) + # nn.init.zeros_(self.weight_noise) + + def cv_squared(self, x, eps=1e-10): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`.s + """ + if x.shape[0] == 1: + return torch.tensor(0.0, device=x.device) + return x.float().var() / (x.float().mean() ** 2 + eps) + + def forward(self, x): + logits_gate = self.gate_network(x) + if self.training and self.add_noise: + noise_mm = self.weight_noise(x) + noise_control = self.softplus(noise_mm) + self.noise_epsilon + logits_noise = torch.randn_like(logits_gate) * noise_control + logits = logits_gate + logits_noise + else: + logits = logits_gate + + top_logits, top_indices = logits.topk( + min(self.num_selects + 1, self.num_experts), dim=1 + ) # 选择并排序前k+1个权重 + top_k_logits = top_logits[:, : self.num_selects] + top_k_indices = top_indices[:, : self.num_selects] + top_k_scores = ( + self.softmax(top_k_logits.to(torch.float32)) + if self.use_softmax + else top_k_logits + ) + top_k_scores = top_k_scores.to(logits.dtype) + + zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device) + scores_filtered = zeros.scatter( + dim=1, index=top_k_indices, src=top_k_scores + ) # shape(batch_size, num_experts) + importance = scores_filtered.sum(0) # shape(num_experts) + + if self.training: + if self.add_noise and self.num_selects != self.num_experts: + batch_size = top_logits.size(0) + m = top_logits.size(1) + top_values_flat = top_logits.flatten() + threshold_positions_if_in = ( + torch.arange(batch_size, device=x.device) * m + self.num_selects + ) + threshold_if_in = torch.unsqueeze( + torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 + ) + is_in = torch.gt(logits_noise, threshold_if_in) + threshold_positions_if_out = threshold_positions_if_in - 1 + threshold_if_out = torch.unsqueeze( + torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 + ) + # is each value currently in the top k. + prob_if_in = self.normal.cdf( + (logits_gate - threshold_if_in) / noise_control + ) + prob_if_out = self.normal.cdf( + (logits_gate - threshold_if_out) / noise_control + ) + prob = torch.where(is_in, prob_if_in, prob_if_out) + load = prob.sum(0) + else: + load = (scores_filtered > 0).sum(0) + if not self.add_noise and not self.warned: + warnings.warn( + 'Gradient-trackable implementation for load calculation is only available when "add_noise=True". ' + 'Training without noise will block the gradient from "load" path and lead to inconsistency in optimization objectives.' + ) + self.warned = True + else: + load = (scores_filtered > 0).sum(0) + + if self.use_balance: + balance_loss = self.cv_squared(importance) + self.cv_squared(load) + balance_loss *= self.balance_loss_weight + else: + balance_loss = torch.tensor(-100.0, device=x.device) + + return { + "topK_indices": top_k_indices, + "topK_scores": top_k_scores, + "balance_loss": balance_loss, + "load": load, + "importance": importance, + } + + +class LinearGLUExperts(nn.Module): + """ + Modified from transformers.models.llama.modeling_llama.LlamaMLP + """ + + __constants__ = [ + "bias", + "in_features", + "hidden_features", + "out_features", + "hidden_act", + "num_experts", + "size_experts", + ] + + def __init__( + self, + in_features, + hidden_features, + out_features, + hidden_act, + num_experts, + size_experts=None, + bias=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super(LinearGLUExperts, self).__init__() + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + self.hidden_act = hidden_act + self.num_experts = num_experts + + if size_experts is None: + # all experts share the same number of hidden neurons + assert hidden_features % num_experts == 0 + size_per_expert = hidden_features // num_experts + size_experts = [size_per_expert for _ in range(num_experts)] + else: + # use specified expert sizes + assert ( + len(size_experts) == num_experts + and sum(size_experts) == hidden_features + ) + self.size_experts = size_experts + + self.act_fn = ACT2FN[hidden_act] + + self.weight_gate = nn.ParameterList() + self.weight_up = nn.ParameterList() + self.weight_down = nn.ParameterList() + + for i in range(num_experts): + # this matrix will be transposed when performing linear forwarding + this_expert_weight_gate = nn.Parameter( + torch.empty((size_experts[i], in_features), **factory_kwargs) + ) + # this matrix will be transposed when performing linear forwarding + this_expert_weight_up = nn.Parameter( + torch.empty((size_experts[i], in_features), **factory_kwargs) + ) + # this matrix will be transposed when performing linear forwarding + this_expert_weight_down = nn.Parameter( + torch.empty((out_features, size_experts[i]), **factory_kwargs) + ) + self.weight_gate.append(this_expert_weight_gate) + self.weight_up.append(this_expert_weight_up) + self.weight_down.append(this_expert_weight_down) + + if bias: + self.bias_gate = nn.ParameterList() + self.bias_up = nn.ParameterList() + self.bias_down = nn.ParameterList() + + for i in range(num_experts): + this_expert_bias_gate = nn.Parameter( + torch.empty((size_experts[i],), **factory_kwargs) + ) + this_expert_bias_up = nn.Parameter( + torch.empty((size_experts[i],), **factory_kwargs) + ) + this_expert_bias_down = nn.Parameter( + torch.empty((out_features,), **factory_kwargs) + ) + self.bias_gate.append(this_expert_bias_gate) + self.bias_up.append(this_expert_bias_up) + self.bias_down.append(this_expert_bias_down) + else: + self.register_parameter("bias_gate", None) + self.register_parameter("bias_up", None) + self.register_parameter("bias_down", None) + + self.reset_parameters() + + def reset_parameters(self): + for i in range(self.num_experts): + nn.init.kaiming_uniform_(self.weight_gate[i], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight_up[i], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight_down[i], a=math.sqrt(5)) + if self.bias_gate is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_gate[i]) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias_gate[i], -bound, bound) + if self.bias_up is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_up[i]) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias_up[i], -bound, bound) + if self.bias_down is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_down[i]) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias_down[i], -bound, bound) + + def forward(self, input, i): + gate = self.act_fn( + F.linear( + input, + self.weight_gate[i], + self.bias_gate[i] if self.bias_gate is not None else None, + ) + ) + up = F.linear( + input, + self.weight_up[i], + self.bias_up[i] if self.bias_up is not None else None, + ) + down = F.linear( + gate * up, + self.weight_down[i], + self.bias_down[i] if self.bias_down is not None else None, + ) + return down + + def extra_repr(self): + return ( + "in_features={}, hidden_features={}, out_features={}, hidden_act={}," + " num_experts={}, size_experts={}, bias={}".format( + self.in_features, + self.hidden_features, + self.out_features, + self.hidden_act, + self.num_experts, + self.size_experts, + self.bias_gate is not None, + ) + ) + + +class UniversalCalculator(nn.Module): + def __init__( + self, + experts: LinearGLUExperts, + multiply_gate_scores=True, + score_scale_factor=1.0, + add_weight_norm: bool = False, + ): + super(UniversalCalculator, self).__init__() + self.experts = experts + # TODO (zhutong): use vmap to boost the training efficiency + # self.experts_vmap = torch.vmap(self.experts) + self.multiply_gate_scores = multiply_gate_scores + self.score_scale_factor = score_scale_factor + self.num_experts = experts.num_experts + self.mlp_norm = None + if multiply_gate_scores and add_weight_norm: + raise NotImplementedError + + def reset_experts(self): + self.experts.reset_parameters() + + def forward( + self, x, topK_indices, topK_scores, expert_batch_size=None, **kwargs + ) -> CalculatorOutput: + batch_size = topK_indices.size(0) # topK_indices: (bsz*seq_len, num_selects) + num_selects = topK_indices.size(1) + topK_indices = topK_indices.flatten() # shape(batch_size*num_selects) + topK_scores = topK_scores.flatten() # shape(batch_size*num_selects) + batch_indices = torch.arange( + batch_size, device=topK_scores.device + ).repeat_interleave(num_selects) + + _, index_sorted_topK_indices = topK_indices.sort(0) + + sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices) + sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices) + + if expert_batch_size is None: + expert_batch_size = topK_indices.bincount( + minlength=self.num_experts + ).tolist() + + sorted_x = x.index_select(0, sorted_batch_indices) + split_x = torch.split(sorted_x, expert_batch_size, dim=0) + + expert_outputs = [ + self.experts(split_x[i], i) + for i in range(self.num_experts) + if split_x[i].shape[0] > 0 + ] + + # (bsz*seq_len*num_selects, hidden_size) + cat_expert_outputs = torch.cat(expert_outputs, 0) + output_dim = cat_expert_outputs.size(1) + if self.multiply_gate_scores: + if self.mlp_norm is None: + cat_expert_outputs = torch.mul( + cat_expert_outputs, + sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor, + ) + # cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * 1.0) + else: + cat_expert_outputs = torch.mul( + cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) + ) + cat_expert_outputs = self.mlp_norm(cat_expert_outputs) + + zeros = torch.zeros( + (batch_size, output_dim), + device=cat_expert_outputs.device, + dtype=cat_expert_outputs.dtype, + ) + y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs) + + return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0)) + + +class BaseMoELayer(nn.Module): + def __init__(self): + super(BaseMoELayer, self).__init__() + + self.gate: TopKBalancedNoisyGate + self.calculator: UniversalCalculator + + def _create_gate(self, **kwargs): + self.gate_type = kwargs.get("gate_type", "TopKBalancedNoisyGate") + + if self.gate_type == "TopKBalancedNoisyGate": # noisy gate + self.gate = TopKBalancedNoisyGate( + self.input_size, + self.num_experts, + self.num_selects, + gate_network=kwargs.get("gate_network", "mlp"), + use_softmax=kwargs.get("gate_use_softmax", True), + use_balance=kwargs.get("gate_use_balance", True), + balance_loss_weight=kwargs.get("gate_balance_loss_weight", 1e-2), + add_noise=kwargs.get("gate_add_noise", True), + noise_epsilon=kwargs.get("gate_noise_epsilon", 1e-2), + ) + else: + raise NotImplementedError + + def _create_calculator(self, experts, **kwargs): + self.calculator_type = kwargs.get("calculator_type", "UniversalCalculator") + + if self.calculator_type == "UniversalCalculator": # top K calculator + self.calculator = UniversalCalculator( + experts, + multiply_gate_scores=kwargs.get("multiply_gate_scores", True), + score_scale_factor=kwargs.get("score_scale_factor", 1.0), + add_weight_norm=kwargs.get("add_weight_norm", False), + ) + else: + raise NotImplementedError + + def forward(self, x) -> MoEMlpOutput: + original_shape = x.shape[:-1] + x = x.reshape(-1, self.input_size) + gate_outputs: dict = self.gate(x) + calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs) + y = calc_outs.hidden_states + y = y.reshape(original_shape + (self.output_size,)) + + return MoEMlpOutput( + hidden_states=y, + balance_loss=gate_outputs.get("balance_loss"), + num_dropped_tokens=calc_outs.num_dropped_tokens, + gate_load=gate_outputs.get("load", torch.tensor(-1)), + gate_importance=gate_outputs.get("importance", torch.tensor(-1)), + ) + + def set_num_selects(self, num_selects): + if "num_selects" not in vars(self.gate): + raise KeyError(f'{self.gate_type} does not have a key named "num_selects".') + elif num_selects > self.gate.num_experts: + raise ValueError( + 'The value of "num_selects" must satisfy "num_selects <= num_experts"!' + ) + elif self.gate_type in ("SwitchBalancedGate",): + raise ValueError( + f"{self.gate_type} doesn't support manually setting num_selects." + ) + else: + self.num_selects = num_selects + self.gate.num_selects = num_selects + + def set_gate_use_softmax(self, use_softmax): + if "use_softmax" not in vars(self.gate): + raise KeyError(f'{self.gate_type} does not have a key named "use_softmax".') + else: + self.gate.use_softmax = use_softmax + + def set_gate_use_balance(self, use_balance): + if "use_balance" not in vars(self.gate): + raise KeyError(f'{self.gate_type} does not have a key named "use_balance".') + else: + self.gate.use_balance = use_balance + + def set_gate_balance_loss_weight(self, balance_loss_weight): + if "balance_loss_weight" not in vars(self.gate): + raise KeyError( + f'{self.gate_type} does not have a key named "balance_loss_weight".' + ) + else: + self.gate.balance_loss_weight = balance_loss_weight + + def set_gate_add_noise(self, add_noise): + if "add_noise" not in vars(self.gate): + raise KeyError(f'{self.gate_type} does not have a key named "add_noise".') + else: + self.gate.add_noise = add_noise + + def set_gate_noise_epsilon(self, noise_epsilon): + if "noise_epsilon" not in vars(self.gate): + raise KeyError( + f'{self.gate_type} does not have a key named "noise_epsilon".' + ) + else: + self.gate.noise_epsilon = noise_epsilon + + def set_calculator_multiply_gate_scores(self, multiply_gate_scores): + if "multiply_gate_scores" not in vars(self.calculator): + raise KeyError( + f'{self.gate_type} does not have a key named "multiply_gate_scores".' + ) + else: + self.calculator.multiply_gate_scores = multiply_gate_scores + + def set_calculator_score_scale_factor(self, score_scale_factor): + if "score_scale_factor" not in vars(self.calculator): + raise KeyError( + f'{self.gate_type} does not have a key named "score_scale_factor".' + ) + else: + self.calculator.score_scale_factor = score_scale_factor + + def set_calculator_drop_tokens(self, drop_tokens): + if "drop_tokens" not in vars(self.calculator): + raise KeyError(f'{self.gate_type} does not have a key named "drop_tokens".') + elif ( + drop_tokens + and self.calculator.dropped_padding != "zero" + and self.input_size != self.output_size + ): + warnings.warn( + 'Setting "drop_tokens=True" without zero dropped padding when "input_size != output_size" will cause error!' + ) + else: + self.calculator.drop_tokens = drop_tokens + + def set_calculator_dropped_padding(self, dropped_padding): + if "dropped_padding" not in vars(self.calculator): + raise KeyError( + f'{self.gate_type} does not have a key named "dropped_padding".' + ) + elif dropped_padding not in self.calculator.available_dropped_padding_choices: + raise ValueError( + f"'dropped_padding' type not available! (available choices: {self.calculator.available_dropped_padding_choices})" + ) + elif ( + self.calculator.drop_tokens + and dropped_padding != "zero" + and self.input_size != self.output_size + ): + warnings.warn( + f'Setting "dropped_padding={dropped_padding}" with "drop_tokens=True" when "input_size != output_size" will cause error!' + ) + else: + self.calculator.dropped_padding = dropped_padding + + def set_calculator_capacity_factor(self, capacity_factor): + if "capacity_factor" not in vars(self.calculator): + raise KeyError( + f'{self.gate_type} does not have a key named "capacity_factor".' + ) + else: + self.calculator.capacity_factor = capacity_factor + + def reset_gate_network(self): + self.gate.reset_gate_network() + + def reset_experts(self): + self.calculator.reset_experts() + + +class LinearGLUMoELayer(BaseMoELayer): + def __init__( + self, + input_size, + hidden_size, + output_size, + hidden_act, + num_experts, + num_selects, + size_experts=None, + bias=True, + **kwargs, + ): + super(LinearGLUMoELayer, self).__init__() + assert num_selects <= num_experts + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.hidden_act = hidden_act + self.num_experts = num_experts + self.num_selects = num_selects + self.size_experts = size_experts + self.bias = bias + + experts = LinearGLUExperts( + input_size, + hidden_size, + output_size, + hidden_act, + num_experts, + size_experts=size_experts, + bias=bias, + ) + + self._create_gate(**kwargs) + self._create_calculator(experts, **kwargs) + + +class LlamaMoEDecoderLayer(nn.Module): + def __init__(self, config: LlamaMoEConfig, layer_index): + super().__init__() + + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + gating_config = { + # all gates + "gate_type": config.gate_type, + "gate_network": config.gate_network, + "gate_use_softmax": config.gate_use_softmax, + "gate_use_balance": config.gate_use_balance, + "gate_balance_loss_weight": config.gate_balance_loss_weight, + "gate_add_noise": config.gate_add_noise, + # TopKBalancedNoisyGate + "gate_noise_epsilon": config.gate_noise_epsilon, + } + calculator_config = { + # all calculators + "calculator_type": config.calculator_type, + "multiply_gate_scores": config.multiply_gate_scores, + "score_scale_factor": ( + config.score_scale_factor[layer_index] + if isinstance(config.score_scale_factor, list) + else config.score_scale_factor + ), + "add_weight_norm": config.add_weight_norm, + # SwitchDropTokenCalculator + "drop_tokens": config.drop_tokens, + "dropped_padding": config.dropped_padding, + "capacity_factor": config.capacity_factor, + } + + self.mlp = LinearGLUMoELayer( + input_size=self.hidden_size, + hidden_size=config.intermediate_size, + output_size=self.hidden_size, + hidden_act=config.hidden_act, + num_experts=config.num_experts, + num_selects=config.num_selects, + size_experts=( + config.size_experts[layer_index] + if config.size_experts is not None + else None + ), + bias=False, + **gating_config, + **calculator_config, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + output_attentions=False, + use_cache=False, + ) -> tuple: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + mlp_outs: MoEMlpOutput = self.mlp(hidden_states) + hidden_states = residual + mlp_outs.hidden_states + + outputs = ( + hidden_states, + mlp_outs.balance_loss, + mlp_outs.num_dropped_tokens, + mlp_outs.gate_load, + mlp_outs.gate_importance, + ) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs + + def set_moe_num_selects(self, num_selects): + self.mlp.set_num_selects(num_selects) + + def set_moe_gate_use_softmax(self, use_softmax): + self.mlp.set_gate_use_softmax(use_softmax) + + def set_moe_gate_use_balance(self, use_balance): + self.mlp.set_gate_use_balance(use_balance) + + def set_moe_gate_balance_loss_weight(self, balance_loss_weight): + self.mlp.set_gate_balance_loss_weight(balance_loss_weight) + + def set_moe_gate_add_noise(self, add_noise): + self.mlp.set_gate_add_noise(add_noise) + + def set_moe_gate_noise_epsilon(self, noise_epsilon): + self.mlp.set_gate_noise_epsilon(noise_epsilon) + + def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): + self.mlp.set_calculator_multiply_gate_scores(multiply_gate_scores) + + def set_moe_calculator_score_scale_factor(self, score_scale_factor): + self.mlp.set_calculator_score_scale_factor(score_scale_factor) + + def set_moe_calculator_drop_tokens(self, drop_tokens): + self.mlp.set_calculator_drop_tokens(drop_tokens) + + def set_moe_calculator_dropped_padding(self, dropped_padding): + self.mlp.set_calculator_dropped_padding(dropped_padding) + + def set_moe_calculator_capacity_factor(self, capacity_factor): + self.mlp.set_calculator_capacity_factor(capacity_factor) + + def reset_gate_network(self): + self.mlp.reset_gate_network() + + def reset_experts(self): + self.mlp.reset_experts() + + +class LlamaMoEPreTrainedModel(PreTrainedModel): + config_class = LlamaMoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaMoEDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaMoEModel): + module.gradient_checkpointing = value + + +class LlamaMoEModel(LlamaMoEPreTrainedModel): + def __init__(self, config: LlamaMoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [LlamaMoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at" + " the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + balance_loss = 0.0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing." + " Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + num_dropped_tokens = () + gate_load = () + gate_importance = () + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs: tuple = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs: tuple = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + if layer_outputs[1] is not None: + balance_loss += layer_outputs[1] + + if use_cache: + next_decoder_cache += (layer_outputs[6 if output_attentions else 5],) + + if output_attentions: + all_self_attns += (layer_outputs[5],) + + num_dropped_tokens += (layer_outputs[2],) + gate_load += (layer_outputs[3],) + gate_importance += (layer_outputs[4],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseMoEModelOutputWithPast( + last_hidden_state=hidden_states, + balance_loss=balance_loss, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + num_dropped_tokens=num_dropped_tokens, + gate_load=gate_load, + gate_importance=gate_importance, + ) + + def update_config(self): + self.config.vocab_size = self.config.vocab_size + self.config.max_position_embeddings = self.config.max_position_embeddings + # ↓↓↓↓↓↓↓↓↓↓↓↓ changed here ↓↓↓↓↓↓↓↓↓↓↓↓ # + self.config.hidden_size = self.layers[0].mlp.input_size + self.config.intermediate_size = self.layers[0].mlp.hidden_size + self.config.num_hidden_layers = len(self.layers) + self.config.num_attention_heads = self.layers[0].self_attn.num_heads + self.config.hidden_act = self.layers[0].mlp.hidden_act + # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ # + self.config.initializer_range = self.config.initializer_range + self.config.rms_norm_eps = self.config.rms_norm_eps + self.config.pretraining_tp = self.config.pretraining_tp + self.config.use_cache = self.config.use_cache + self.config.rope_scaling = self.config.rope_scaling + self.config._rope_scaling_validation() + + self.config.num_experts = self.layers[0].mlp.num_experts + self.config.num_selects = self.layers[0].mlp.num_selects + self.config.size_experts = [ + self.layers[i].mlp.calculator.experts.size_experts + for i in range(self.config.num_hidden_layers) + ] + + self.config.gate_type = vars(self.layers[0].mlp).get( + "gate_type", "TopKBalancedNoisyGate" + ) + self.config.gate_network = vars(self.layers[0].mlp.gate).get( + "gate_network_type", "mlp" + ) + self.config.gate_use_softmax = vars(self.layers[0].mlp.gate).get( + "use_softmax", True + ) + self.config.gate_use_balance = vars(self.layers[0].mlp.gate).get( + "use_balance", True + ) + self.config.gate_balance_loss_weight = vars(self.layers[0].mlp.gate).get( + "balance_loss_weight", 1e-2 + ) + self.config.gate_add_noise = vars(self.layers[0].mlp.gate).get( + "add_noise", True + ) + self.config.gate_noise_epsilon = vars(self.layers[0].mlp.gate).get( + "noise_epsilon", 1e-2 + ) + + self.config.calculator_type = vars(self.layers[0].mlp).get( + "calculator_type", "UniversalCalculator" + ) + self.config.multiply_gate_scores = vars(self.layers[0].mlp.calculator).get( + "multiply_gate_scores", True + ) + self.config.score_scale_factor = [ + vars(self.layers[i].mlp.calculator).get("score_scale_factor", 1.0) + for i in range(self.config.num_hidden_layers) + ] + self.config.drop_tokens = vars(self.layers[0].mlp.calculator).get( + "drop_tokens", True + ) + self.config.dropped_padding = vars(self.layers[0].mlp.calculator).get( + "dropped_padding", "zero" + ) + self.config.capacity_factor = vars(self.layers[0].mlp.calculator).get( + "capacity_factor", 1.25 + ) + + def set_moe_num_selects(self, num_selects): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_num_selects(num_selects) + + def set_moe_gate_use_softmax(self, use_softmax): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_gate_use_softmax(use_softmax) + + def set_moe_gate_use_balance(self, use_balance): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_gate_use_balance(use_balance) + + def set_moe_gate_balance_loss_weight(self, balance_loss_weight): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_gate_balance_loss_weight(balance_loss_weight) + + def set_moe_gate_add_noise(self, add_noise): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_gate_add_noise(add_noise) + + def set_moe_gate_noise_epsilon(self, noise_epsilon): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_gate_noise_epsilon(noise_epsilon) + + def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) + + def set_moe_calculator_score_scale_factor( + self, score_scale_factor, layer_index=None + ): + if layer_index is None: + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_calculator_score_scale_factor(score_scale_factor) + else: + self.layers[layer_index].set_moe_calculator_score_scale_factor( + score_scale_factor + ) + + def set_moe_calculator_drop_tokens(self, drop_tokens): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_calculator_drop_tokens(drop_tokens) + + def set_moe_calculator_dropped_padding(self, dropped_padding): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_calculator_dropped_padding(dropped_padding) + + def set_moe_calculator_capacity_factor(self, capacity_factor): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.set_moe_calculator_capacity_factor(capacity_factor) + + def reset_gate_network(self): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.reset_gate_network() + + def reset_experts(self): + for idx, decoder_layer in enumerate(self.layers): + decoder_layer.reset_experts() + + +class LlamaMoEForCausalLM(LlamaMoEPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaMoEModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseMoEModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if outputs.balance_loss is not None and outputs.balance_loss > 0: + loss += outputs.balance_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoECausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + num_dropped_tokens=outputs.num_dropped_tokens, + balance_loss=outputs.balance_loss, + gate_load=outputs.gate_load, + gate_importance=outputs.gate_importance, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + def update_config(self): + self.model.update_config() + + def set_moe_num_selects(self, num_selects): + self.model.set_moe_num_selects(num_selects) + + def set_moe_gate_use_softmax(self, use_softmax): + self.model.set_moe_gate_use_softmax(use_softmax) + + def set_moe_gate_use_balance(self, use_balance): + self.model.set_moe_gate_use_balance(use_balance) + + def set_moe_gate_balance_loss_weight(self, balance_loss_weight): + self.model.set_moe_gate_balance_loss_weight(balance_loss_weight) + + def set_moe_gate_add_noise(self, add_noise): + self.model.set_moe_gate_add_noise(add_noise) + + def set_moe_gate_noise_epsilon(self, noise_epsilon): + self.model.set_moe_gate_noise_epsilon(noise_epsilon) + + def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): + self.model.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) + + def set_moe_calculator_score_scale_factor( + self, score_scale_factor, layer_index=None + ): + self.model.set_moe_calculator_score_scale_factor( + score_scale_factor, layer_index=layer_index + ) + + def set_moe_calculator_drop_tokens(self, drop_tokens): + self.model.set_moe_calculator_drop_tokens(drop_tokens) + + def set_moe_calculator_dropped_padding(self, dropped_padding): + self.model.set_moe_calculator_dropped_padding(dropped_padding) + + def set_moe_calculator_capacity_factor(self, capacity_factor): + self.model.set_moe_calculator_capacity_factor(capacity_factor) + + def reset_gate_network(self): + self.model.reset_gate_network() + + def reset_experts(self): + self.model.reset_experts() diff --git a/smoe/models/mistral/__init__.py b/smoe/models/mistral/__init__.py new file mode 100644 index 0000000..90bffae --- /dev/null +++ b/smoe/models/mistral/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + +_import_structure = { + "configuration_mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mistral"] = [ + "MistralForCausalLM", + "MistralModel", + "MistralPreTrainedModel", + "MistralForSequenceClassification", + ] + + +if TYPE_CHECKING: + from .configuration_mistral import ( + MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, + MistralConfig, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mistral import ( + MistralForCausalLM, + MistralForSequenceClassification, + MistralModel, + MistralPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__ + ) diff --git a/smoe/models/mistral/configuration_mistral.py b/smoe/models/mistral/configuration_mistral.py new file mode 100644 index 0000000..3c16d6e --- /dev/null +++ b/smoe/models/mistral/configuration_mistral.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mistral model configuration""" + +import copy +from typing import Any, Dict + +from transformers import __version__ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", + "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", +} + + +def recursive_diff_dict(dict_a, dict_b, config_obj=None): + """ + Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the + values from `dict_a` that are different from values in `dict_b`. + """ + diff = {} + default = config_obj.__class__().to_dict() if config_obj is not None else {} + for key, value in dict_a.items(): + obj_value = getattr(config_obj, str(key), None) + if ( + isinstance(obj_value, PretrainedConfig) + and key in dict_b + and isinstance(dict_b[key], dict) + ): + diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value) + if len(diff_value) > 0: + diff[key] = diff_value + elif ( + key not in dict_b + or value != dict_b[key] + or key not in default + or value != default[key] + ): + diff[key] = value + return diff + + +class MistralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an + Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. + + [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) + [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MistralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mistral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import MistralModel, MistralConfig + + >>> # Initializing a Mistral 7B style configuration + >>> configuration = MistralConfig() + + >>> # Initializing a model from the Mistral 7B style configuration + >>> model = MistralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + # Attention implementation to use, if relevant. + self._attn_implementation_internal = kwargs.pop("attn_implementation", None) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] + + # Transformers version when serializing the model + output["transformers_version"] = __version__ + + for key, value in output.items(): + # Deal with nested configs like CLIP + if isinstance(value, PretrainedConfig): + value = value.to_dict() + del value["transformers_version"] + + output[key] = value + + if hasattr(self, "quantization_config"): + output["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(output) + + return output + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + # get class specific config dict + class_config_dict = ( + self.__class__().to_dict() if not self.is_composition else {} + ) + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if ( + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) + ): + # For nested configs we need to clean the diff recursively + diff = recursive_diff_dict( + value, class_config_dict[key], config_obj=getattr(self, key, None) + ) + if "model_type" in value: + # Needs to be set even if it's not in the diff + diff["model_type"] = value["model_type"] + if len(diff) > 0: + serializable_config_dict[key] = diff + elif ( + key not in default_config_dict + or key == "transformers_version" + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): + serializable_config_dict[key] = value + + if hasattr(self, "quantization_config"): + serializable_config_dict["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = serializable_config_dict.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(serializable_config_dict) + + if "_attn_implementation_internal" in serializable_config_dict: + del serializable_config_dict["_attn_implementation_internal"] + + return serializable_config_dict diff --git a/smoe/models/mistral/modeling_mistral.py b/smoe/models/mistral/modeling_mistral.py new file mode 100644 index 0000000..395c68a --- /dev/null +++ b/smoe/models/mistral/modeling_mistral.py @@ -0,0 +1,1488 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mistral model.""" +import importlib +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_available, + logging, + replace_return_docstrings, +) + +from smoe.utils.cache_utils import Cache, DynamicCache +from smoe.utils.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +from .configuration_mistral import MistralConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" + + +def _is_package_available( + pkg_name: str, return_version: bool = False +) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + logger.debug(f"Detected {pkg_name} version {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +def is_flash_attn_2_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn"): + return False + + # Let's add an extra check to see if cuda is available + import torch + + if not torch.cuda.is_available(): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.0.4" + ) + else: + return False + + +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], + dim=-1, + ) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, +} + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + MistralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/smoe/utils/param_estimation.py b/smoe/utils/param_estimation.py index 7f4a7ed..fd6cac3 100644 --- a/smoe/utils/param_estimation.py +++ b/smoe/utils/param_estimation.py @@ -115,6 +115,8 @@ def normal_moe_param( print("7B", res_7B) res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 2) print("7B 2/16", res_7B) + res_7B = estimate_moe_param(32000, 4096, 32, 11008, 8, 2) + print("7B 2/8", res_7B) res_7B = estimate_moe_param(32000, 4096, 32, 11008, 16, 1) print("7B 1/16", res_7B) res_7B = estimate_moe_param(32000, 2560, 32, 11008, 8, 2) From 01eac97538ae8d2ea858565af2c3f987c1a93163 Mon Sep 17 00:00:00 2001 From: zhutong Date: Sat, 23 Dec 2023 22:07:07 +0800 Subject: [PATCH 09/12] update readme --- README.md | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index c0dc348..bd801c1 100644 --- a/README.md +++ b/README.md @@ -23,19 +23,6 @@ We build LLaMA-MoE with the following two steps: | **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | - -| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | -| :------------------------------------------------------------------------------------ | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :-----: | -| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | 50.3 | -| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | 51.5 | -| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | 53.7 | -| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | 55.6 | -| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | 56.4 | -| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | 55.5 | -| **LLaMA-MoE-3.5B (4/16)** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | 57.7 | -| **LLaMA-MoE-3.5B (2/8)** | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 | 57.6 | - -

🚀 QuickStart

```python @@ -57,6 +44,20 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) # Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three ``` +

📊 Model Performance

+ +| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | +| :------------------------------------------------------------------------------------ | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :------: | +| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | 50.3 | +| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | 51.5 | +| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | 53.7 | +| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | 55.6 | +| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | 56.4 | +| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | 55.5 | +| **LLaMA-MoE-3.5B (4/16)** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | **57.7** | +| **LLaMA-MoE-3.5B (2/8)** | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 | 57.6 | + +

🚧 Expert Construction

- Neuron-Independent From 1f945d1c2393d113736d74c7447473dbdc7cdbd8 Mon Sep 17 00:00:00 2001 From: zhutong Date: Sun, 24 Dec 2023 15:58:12 +0800 Subject: [PATCH 10/12] update features and descriptions --- .gitattributes | 1 + README.md | 58 +++++++++++++++++++++++++++------------ docs/imgs/MoE-Routing.gif | 3 ++ 3 files changed, 44 insertions(+), 18 deletions(-) create mode 100644 docs/imgs/MoE-Routing.gif diff --git a/.gitattributes b/.gitattributes index 8dc584d..3579f7b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ docs/imgs/title-favicon.png filter=lfs diff=lfs merge=lfs -text +docs/imgs/MoE-Routing.gif filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index bd801c1..d6adaa3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training

LLaMA-MoE favicon
- 📢 A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!! + 📢 A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!!
🤗 Model Weights | 📃 Technical Report | 🚀 Quick Start
⚙️ Installation Guide | 🚧 Expert Construction | 🚅 Continual Pre-training | 💎 Evaluation @@ -10,17 +10,33 @@

🎉 Introduction

-LLaMA-MoE is a series of Mixture-of-Expert (MoE) models based on [LLaMA](https://github.com/facebookresearch/llama). +LLaMA-MoE is a series of open-sourced Mixture-of-Expert (MoE) models based on [LLaMA](https://github.com/facebookresearch/llama) and [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama). We build LLaMA-MoE with the following two steps: 1. Partition LLaMA's FFNs into sparse experts and insert top-K gate for each layer of experts. 2. Continually pre-train the initialized MoE model with an optimized data sampling weights from [Sheared LLaMA](https://arxiv.org/abs/2310.06694) and filtered datasets from [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama). - -| Model | \#Activated Experts | \#Experts | \#Activated Params | Links | -| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: | -| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | -| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | -| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | +![MoE Routing](./docs/imgs/MoE-Routing.gif) + +

🔥 Features

+ +1. **Lightweight Models**: The total number of model parameters is only 6.7B, which is easy for deployment and research usage. +2. **Multiple Expert Construction Methods**: + 1. Neuron-Independent: Random, Clustering, Co-activation Graph, Gradient ([Zhang et al., 2022](http://arxiv.org/abs/2110.01786), [Zuo et al., 2022](http://arxiv.org/abs/2204.07675)) + 2. Neuron-Sharing: Inner, Inter (residual) +3. **Multiple MoE Gating Strategies**: + 1. TopK Noisy Gate ([Shazeer et al., 2017](http://arxiv.org/abs/1701.06538)) + 2. Switch Gating ([Fedus et al., 2022](http://arxiv.org/abs/2101.03961)) +4. **Fast Continual Pre-training**: + 1. FlashAttention-v2 integrated ([Dao, 2023](https://github.com/Dao-AILab/flash-attention)) + 2. Fast streaming dataset loading +5. **Abundant Monitor Items**: + 1. Gate load, gate importance + 2. Loss on steps, loss on tokens, balance loss + 3. TGS (tokens/GPU/second), MFU (model FLOPs utilization) + 4. Other visualization utilities +6. **Dynamic Weight Sampling**: + 1. Self-defined static sampling weights + 2. Sheared LLaMA's dynamic batch loading ([Xia et al., 2023](http://arxiv.org/abs/2310.06694))

🚀 QuickStart

@@ -46,16 +62,22 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))

📊 Model Performance

-| Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average | -| :------------------------------------------------------------------------------------ | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :------: | -| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | 50.3 | -| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | 51.5 | -| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | 53.7 | -| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | 55.6 | -| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | 56.4 | -| **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | 55.5 | -| **LLaMA-MoE-3.5B (4/16)** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | **57.7** | -| **LLaMA-MoE-3.5B (2/8)** | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 | 57.6 | +| Model | \#Activated Experts | \#Experts | \#Activated Params | Links | +| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: | +| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | +| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | +| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | + +| Model | Average | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | +| :------------------------------------------------------------------------------------ | :------: | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | +| [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 50.3 | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | +| [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 51.5 | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | +| [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 53.7 | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | +| [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 55.6 | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | +| [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 56.4 | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | +| **LLaMA-MoE-3.0B** | 55.5 | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | +| **LLaMA-MoE-3.5B (4/16)** | **57.7** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | +| **LLaMA-MoE-3.5B (2/8)** | 57.6 | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 |

🚧 Expert Construction

diff --git a/docs/imgs/MoE-Routing.gif b/docs/imgs/MoE-Routing.gif new file mode 100644 index 0000000..6c8e2d6 --- /dev/null +++ b/docs/imgs/MoE-Routing.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0a31562b85a1ad8d7e62c58dcb3f60bdf19ed70b4becf3f3b0ae51ae1ec19bd +size 608200 From fe0ee64d98a8f267bded2b33905baa69f3ad15d7 Mon Sep 17 00:00:00 2001 From: zhutong Date: Sun, 24 Dec 2023 16:03:42 +0800 Subject: [PATCH 11/12] modify words --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d6adaa3..de53ab2 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ We build LLaMA-MoE with the following two steps:

🔥 Features

-1. **Lightweight Models**: The total number of model parameters is only 6.7B, which is easy for deployment and research usage. +1. **Lightweight Models**: The total number of model parameters is only 6.7B, which is friendly for deployment and research usage. 2. **Multiple Expert Construction Methods**: 1. Neuron-Independent: Random, Clustering, Co-activation Graph, Gradient ([Zhang et al., 2022](http://arxiv.org/abs/2110.01786), [Zuo et al., 2022](http://arxiv.org/abs/2204.07675)) 2. Neuron-Sharing: Inner, Inter (residual) From 530af610990e51548a0486abd20d541dd1830a79 Mon Sep 17 00:00:00 2001 From: zhutong Date: Sun, 24 Dec 2023 16:10:19 +0800 Subject: [PATCH 12/12] update scripts --- .vscode/launch.json | 4 ++-- docs/continual_pretraining/README.md | 4 ++-- .../baseline_112gpus_sheared_llama_portion_fluency_sf4.sh | 1 - .../baseline_112gpus_sheared_llama_portion_fluency_sf8.sh | 1 - .../baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh | 1 - scripts/cpt/dynamic_data_selection/baseline_112gpus.sh | 1 - .../dynamic_data_selection/baseline_112gpus_linear_gate.sh | 1 - .../cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh | 1 - .../baseline_112gpus_sheared_llama_portion.sh | 1 - .../baseline_112gpus_sheared_llama_portion_fluency.sh | 1 - ...line_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh | 1 - .../baseline_112gpus_sheared_llama_portion_no_ad.sh | 1 - scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh | 1 - .../cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh | 1 - scripts/cpt/fpt.sh | 1 - scripts/cpt/fpt_resume.sh | 1 - scripts/cpt/fpt_switch.sh | 4 ---- scripts/cpt/gate_loss.sh | 1 - scripts/moefication/select/run_select.sh | 1 - scripts/test/test_conn.sh | 1 - scripts/tokenize/clustering.sh | 2 +- 21 files changed, 5 insertions(+), 26 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 96f81b0..2218cf6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,7 +16,7 @@ "type": "python", "request": "attach", "connect": { - "host": "SH-IDCA1404-10-140-54-12", + "host": "x.x.x.x", "port": 5678 }, "pathMappings": [ @@ -28,4 +28,4 @@ "justMyCode": false } ] -} \ No newline at end of file +} diff --git a/docs/continual_pretraining/README.md b/docs/continual_pretraining/README.md index 6e5e441..05bbc17 100644 --- a/docs/continual_pretraining/README.md +++ b/docs/continual_pretraining/README.md @@ -90,7 +90,7 @@ Based on the above information, the expected time could be calculated. The tensorboard `logging_dir` could be found at `outputs/-/runs/`. -For example, if my job name is `cpt-moe-fpt-bs16-48gpus` in the sbatch file, the tensorboard could be started from that by: `tensorboard --logdir outputs/cpt-moe-fpt-bs16-48gpus-1535835/runs/Jul31_14-12-00_SH-IDCA1404-10-140-54-100` . +For example, if my job name is `cpt-moe-fpt-bs16-48gpus` in the sbatch file, the tensorboard could be started from that by: `tensorboard --logdir outputs/cpt-moe-fpt-bs16-48gpus-1535835/runs/Jul31_14-12-00` . For multiple tasks with different logging directories, you could run the following command: @@ -101,5 +101,5 @@ $ tensorboard --logdir_spec short_name:dir1,short_name2:dir2 --port 8001 Here, the `short_name` is an abbreviation for your task, and the port number could be changed manually if there's a port conflict. e.g. ```bash -$ tensorboard --logdir_spec moe_from_scratch:outputs/cpt-llama-moe-scratch-lora-bs16-1476932/runs/Jul26_21-53-42_SH-IDCA1404-10-140-54-121,moe_lora:outputs/cpt-llama-lora-bs16-1476918/runs/Jul26_21-31-09_SH-IDCA1404-10-140-54-122 --port 8001 +$ tensorboard --logdir_spec moe_from_scratch:outputs/cpt-llama-moe-scratch-lora-bs16-1476932/runs/Jul26_21-53-42,moe_lora:outputs/cpt-llama-lora-bs16-1476918/runs/Jul26_21-31-09 --port 8001 ``` diff --git a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh index ff4c760..dad85c1 100644 --- a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh +++ b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf4.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 # reserved spot diff --git a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh index 51a90fd..4ac1316 100644 --- a/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh +++ b/scripts/cpt/16_2/baseline_112gpus_sheared_llama_portion_fluency_sf8.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 # reserved spot diff --git a/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh b/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh index 177c55a..861e458 100644 --- a/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh +++ b/scripts/cpt/8_2/baseline_112gpus_8_2_sheared_llama_portion_fluency_sf4.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh index 1929a77..c8f2768 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh index 03c65e0..2426ecb 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_linear_gate.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh index 025a72a..f3231e3 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh index fc60f05..9d21b65 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh index ae8a38c..abaaaf5 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_fluency.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36,SH-IDCA1404-10-140-54-24 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh index 3de43f8..619ff3e 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_gate_balance_loss0.1.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh index 4d93400..129bfc1 100644 --- a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion_no_ad.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh index a24d857..b6ffcb1 100644 --- a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh index fe2af8f..1633f35 100644 --- a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=14 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -x SH-IDCA1404-10-140-54-36 # reserved spot diff --git a/scripts/cpt/fpt.sh b/scripts/cpt/fpt.sh index 056ff02..5a124f3 100644 --- a/scripts/cpt/fpt.sh +++ b/scripts/cpt/fpt.sh @@ -8,7 +8,6 @@ #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116,SH-IDCA1404-10-140-54-70 #SBATCH --nodes=1 #SBATCH --gres=gpu:8 diff --git a/scripts/cpt/fpt_resume.sh b/scripts/cpt/fpt_resume.sh index b5919b5..e8cad6f 100644 --- a/scripts/cpt/fpt_resume.sh +++ b/scripts/cpt/fpt_resume.sh @@ -8,7 +8,6 @@ #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=64 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116 #SBATCH --nodes=7 #SBATCH --gres=gpu:8 diff --git a/scripts/cpt/fpt_switch.sh b/scripts/cpt/fpt_switch.sh index eadb245..8c23c26 100644 --- a/scripts/cpt/fpt_switch.sh +++ b/scripts/cpt/fpt_switch.sh @@ -10,7 +10,6 @@ #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116,SH-IDCA1404-10-140-54-15 #SBATCH --nodes=7 #SBATCH --gres=gpu:8 @@ -127,6 +126,3 @@ export LOGLEVEL=INFO --report_to none \ --log_level info } - -# srun -p MoE -n1 -N1 -w SH-IDCA1404-10-140-54-43 scontrol listpids -# srun -p MoE -n1 -N1 -w SH-IDCA1404-10-140-54-43 py-spy dump --pid 118340 diff --git a/scripts/cpt/gate_loss.sh b/scripts/cpt/gate_loss.sh index 692c894..d5440f2 100644 --- a/scripts/cpt/gate_loss.sh +++ b/scripts/cpt/gate_loss.sh @@ -8,7 +8,6 @@ #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116 #SBATCH --nodes=1 #SBATCH --gres=gpu:8 diff --git a/scripts/moefication/select/run_select.sh b/scripts/moefication/select/run_select.sh index 8548b0e..9d240d7 100644 --- a/scripts/moefication/select/run_select.sh +++ b/scripts/moefication/select/run_select.sh @@ -25,7 +25,6 @@ save_path=${data_path}/moefication_results/select/${split_type} save_visualization_path=/mnt/petrelfs/dongdaize.d/workspace/train-moe/visualization/expert/${split_type}-${select_type}/${llama_size}-${num_experts}Select${num_selects}-${proj_type} #node=108 -# -w SH-IDCA1404-10-140-54-${node} \ gpus=1 cpus=16 for specify_layer in "0 1" "2 3" "4 5" "6 7" "8 9" "10 11" "12 13" "14 15" "16 17" "18 19" "20 21" "22 23" "24 25" "26 27" "28 29" "30 31"; do # 并行启用任务 diff --git a/scripts/test/test_conn.sh b/scripts/test/test_conn.sh index dcf13bc..bbc5fd8 100644 --- a/scripts/test/test_conn.sh +++ b/scripts/test/test_conn.sh @@ -12,7 +12,6 @@ #SBATCH --nodes=3 #SBATCH --gres=gpu:8 #SBATCH --quotatype=reserved -#SBATCH -w SH-IDCA1404-10-140-54-11,SH-IDCA1404-10-140-54-36 export OMP_NUM_THREADS=4 diff --git a/scripts/tokenize/clustering.sh b/scripts/tokenize/clustering.sh index dbdfae8..26f648d 100644 --- a/scripts/tokenize/clustering.sh +++ b/scripts/tokenize/clustering.sh @@ -14,7 +14,7 @@ mkdir -p $logs_dir for data_type in $(ls $data_dir) do log_path=logs/tokenize_${data_type}_32clusters.log - nohup srun -p MoE -N1 -n1 --cpus-per-task=32 -x "SH-IDCA1404-10-140-54-[12,18,33,38,41,43,63,70-71,74,83,85]" \ + nohup srun -p MoE -N1 -n1 --cpus-per-task=32 \ python -m smoe.utils.tokenize \ -f jsonl \ -t $tokenizer_dir \