diff --git a/.gitignore b/.gitignore index a2bf5e6..5c2bc94 100644 --- a/.gitignore +++ b/.gitignore @@ -167,7 +167,7 @@ outputs/ /visualization/ results/analysis/cluster_*.png results/expert_load_vis -results/analysis_clustering7 +results/analysis_clustering* results/gate_loss_100b results/RandomSplit-l2_norm-llama_7B-16Select4-up_proj results/gate_loss_original_clustering_model diff --git a/.vscode/launch.json b/.vscode/launch.json index 0b4f69f..6010699 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "type": "python", "request": "attach", "connect": { - "host": "SH-IDCA1404-10-140-54-115", + "host": "SH-IDCA1404-10-140-54-122", "port": 5678 }, "pathMappings": [ diff --git a/scripts/cpt/fpt.sh b/scripts/cpt/fpt.sh index 798d40b..056ff02 100644 --- a/scripts/cpt/fpt.sh +++ b/scripts/cpt/fpt.sh @@ -1,6 +1,6 @@ #!/usr/bin/bash -#SBATCH --job-name=cpt-16select4-64gpus +#SBATCH --job-name=cpt-7b-test #SBATCH --output=logs/%x-%j.log #SBATCH --error=logs/%x-%j.log @@ -8,26 +8,30 @@ #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116 +#SBATCH -x SH-IDCA1404-10-140-54-116,SH-IDCA1404-10-140-54-70 -#SBATCH --nodes=7 +#SBATCH --nodes=1 #SBATCH --gres=gpu:8 source ~/anaconda3/bin/activate smoe -num_nodes=7 # should match with --nodes -num_gpu_per_node=8 # should match with --gres - -# #cpu/#num_gpu_per_node -export OMP_NUM_THREADS=4 -export LOGLEVEL=INFO -# export NCCL_DEBUG=INFO -# export TORCH_DISTRIBUTED_DEBUG=DETAIL -# export TORCH_SHOW_CPP_STACKTRACES=1 -# export CUDA_LAUNCH_BLOCKING=1 { + num_nodes=1 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=16 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + comment="exp purpose" + # model_type="llama" + # pretrained_model="outputs/llama1_7B_random" # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B model_type="llama_moe" pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4-l2_norm_bak @@ -40,13 +44,14 @@ export LOGLEVEL=INFO # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed - lr=3e-4 + lr=1e-4 final_lr_portion=0.1 - per_device_train_batch_size=8 + per_device_train_batch_size=16 per_device_eval_batch_size=1 - gradient_accumulation_steps=4 + gradient_accumulation_steps=2 block_size=2048 num_tokens="1*10^11" + seed=1227 deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) @@ -61,8 +66,11 @@ export LOGLEVEL=INFO data_cache=resources/cache output_dir=outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID mkdir -p $output_dir - scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/.env + echo $comment > $output_dir/comment.txt nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIS ) ) nodes_array=($nodes) @@ -78,7 +86,7 @@ export LOGLEVEL=INFO --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + 
smoe/entrypoint/cpt/cpt_fpt.py \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ --model_type ${model_type} \ @@ -89,7 +97,7 @@ export LOGLEVEL=INFO --per_device_train_batch_size ${per_device_train_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ --do_train \ - --seed $RANDOM \ + --seed ${seed} \ --bf16 \ --num_train_epochs 1 \ --final_lr_portion ${final_lr_portion} \ @@ -102,8 +110,6 @@ export LOGLEVEL=INFO --warmup_steps 2000 \ --max_steps ${max_steps} \ --max_train_samples ${max_train_samples} \ - --logging_strategy steps \ - --logging_steps 10 \ --save_strategy steps \ --save_total_limit 2 \ --save_steps 1000 \ @@ -113,12 +119,16 @@ export LOGLEVEL=INFO --output_dir ${output_dir} \ --overwrite_output_dir \ --ddp_timeout 30000 \ - --logging_first_step True \ - --torch_dtype bfloat16 \ --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ --gradient_checkpointing \ - --report_to none \ - --log_level info + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 10 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none } #SBATCH --job-name=cpt-moe-fpt-test_lr_change #改动前:--logging_steps 10 \ diff --git a/scripts/cpt/fpt_13b.sh b/scripts/cpt/fpt_13b.sh index 5561804..cab4512 100644 --- a/scripts/cpt/fpt_13b.sh +++ b/scripts/cpt/fpt_13b.sh @@ -1,59 +1,80 @@ #!/usr/bin/bash -#SBATCH --job-name=cpt-moe-fpt-13b-64gpus-bs8_4-task_test +#SBATCH --job-name=cpt-13b-test #SBATCH --output=logs/%x-%j.log #SBATCH --error=logs/%x-%j.log +##SBATCH --output=logs/%x.log +##SBATCH --error=logs/%x.log #SBATCH --partition=MoE #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --mem=0 -#SBATCH -x SH-IDCA1404-10-140-54-116 -#SBATCH --time=8:00:00 -#SBATCH --nodes=8 +#SBATCH --nodes=2 #SBATCH --gres=gpu:8 +#SBATCH --quotatype=auto +##SBATCH --time=5:00:00 source ~/anaconda3/bin/activate smoe { - num_nodes=8 # should match with --nodes + num_nodes=2 # should match with --nodes num_gpu_per_node=8 # should match with --gres # #cpu/#num_gpu_per_node - export OMP_NUM_THREADS=4 + export OMP_NUM_THREADS=16 export LOGLEVEL=INFO # export NCCL_DEBUG=INFO # export TORCH_DISTRIBUTED_DEBUG=DETAIL # export TORCH_SHOW_CPP_STACKTRACES=1 # export CUDA_LAUNCH_BLOCKING=1 - lr=3e-4 + # comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4, expert weight re-scale" + comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4" + # comment="random initialized llama1-7B" + # comment="random initialized llama1-13B" + # comment="7B, expert 4/16, noisy gate, gradient shared neurons, w/o residual, w/o weight re-scale, lr2e-4" + # comment="3B MoE, debug" # model_type="llama" - # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B - # model_type="llama_moe" - # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4-l2_norm + # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/llama_13B" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random model_type="llama_moe" - pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Clustering-l2/llama_13B-16Select4-up_proj" + # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_3B-8Select2-4320Neurons-Share" + # 
pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-16Select4-688Neurons-Share" + pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons-Share" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4-l2_norm + # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Clustering-l2/llama_13B-16Select4-up_proj" # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj - # pretrained_model=$1 - echo "==================> $pretrained_model <==================" # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Clustering-l2/llama_13B-16Select4-up_proj # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Graph-l2_norm/llama_13B-16Select4-up_proj # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama_13B-16Select4-up_proj + # pretrained_model=$1 + echo "==================> $pretrained_model <==================" # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B - # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj + # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax-copy/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj + # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_13B + # tokenizer_path="/mnt/petrelfs/share_data/quxiaoye/models/llama_3B" + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed + # dataset_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized + lr=2e-4 + final_lr_portion=0.1 per_device_train_batch_size=8 per_device_eval_batch_size=1 gradient_accumulation_steps=4 + num_tokens="3*10^11" + seed=1227 block_size=2048 - max_steps=$(echo "10^11 / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) - max_train_samples=$(echo "10^11 / $block_size" | bc) + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / $block_size" | bc) echo "max_steps: $max_steps" echo "max_train_samples: $max_train_samples" global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) @@ -63,8 +84,13 @@ source ~/anaconda3/bin/activate smoe data_cache=resources/cache output_dir=outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + # output_dir=/mnt/petrelfs/share_data/quxiaoye/models/tzhu_model_bak/cpt-13b-16gpus-lr2e-4 + mkdir -p $output_dir echo "output_dir: $output_dir" - deepspeed_config_file=conf/deepspeed/bf16_zero2_default.json + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/.env + echo $comment > $output_dir/comment.txt nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIS ) ) nodes_array=($nodes) @@ -73,6 +99,7 @@ source ~/anaconda3/bin/activate smoe echo "Node: $head_node" echo "Node IP: $head_node_ip" + # --resume_from_checkpoint 
/mnt/petrelfs/share_data/quxiaoye/models/tzhu_model_bak/cpt-13b-16gpus-lr2e-4/checkpoint-2000 \ srun torchrun \ --nnodes ${num_nodes} \ --nproc_per_node ${num_gpu_per_node} \ @@ -80,7 +107,7 @@ source ~/anaconda3/bin/activate smoe --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + smoe/entrypoint/cpt/cpt_fpt.py \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ --model_type ${model_type} \ @@ -91,10 +118,10 @@ source ~/anaconda3/bin/activate smoe --per_device_train_batch_size ${per_device_train_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ --do_train \ - --seed $RANDOM \ + --seed ${seed} \ --bf16 \ --num_train_epochs 1 \ - --final_lr_portion 0.1 \ + --final_lr_portion ${final_lr_portion} \ --optim adamw_torch \ --adam_beta1 0.9 \ --adam_beta2 0.95 \ @@ -103,9 +130,7 @@ source ~/anaconda3/bin/activate smoe --max_grad_norm 1.0 \ --warmup_steps 2000 \ --max_steps ${max_steps} \ - --max_train_samples 48828125 \ - --logging_strategy steps \ - --logging_steps 10 \ + --max_train_samples ${max_train_samples} \ --save_strategy steps \ --save_total_limit 1 \ --save_steps 1000 \ @@ -115,10 +140,18 @@ source ~/anaconda3/bin/activate smoe --output_dir ${output_dir} \ --overwrite_output_dir \ --ddp_timeout 30000 \ - --logging_first_step True \ - --torch_dtype bfloat16 \ --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ --gradient_checkpointing \ - --report_to none \ - --log_level info + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 10 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects 4 \ + --report_to none + } diff --git a/scripts/cpt/fpt_resume.sh b/scripts/cpt/fpt_resume.sh index 01785c3..b5919b5 100644 --- a/scripts/cpt/fpt_resume.sh +++ b/scripts/cpt/fpt_resume.sh @@ -76,7 +76,7 @@ export LOGLEVEL=INFO --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + smoe/entrypoint/cpt/cpt_fpt.py \ --ignore_data_skip \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ diff --git a/scripts/cpt/fpt_switch.sh b/scripts/cpt/fpt_switch.sh index 0fb47f6..eadb245 100644 --- a/scripts/cpt/fpt_switch.sh +++ b/scripts/cpt/fpt_switch.sh @@ -82,7 +82,7 @@ export LOGLEVEL=INFO --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + smoe/entrypoint/cpt/cpt_fpt.py \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ --model_type ${model_type} \ diff --git a/scripts/cpt/fpt_test_lr.sh b/scripts/cpt/fpt_test_lr.sh index cdf37d8..5a07f17 100644 --- a/scripts/cpt/fpt_test_lr.sh +++ b/scripts/cpt/fpt_test_lr.sh @@ -69,7 +69,7 @@ export LOGLEVEL=INFO --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + smoe/entrypoint/cpt/cpt_fpt.py \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ --model_type ${model_type} \ diff --git a/scripts/cpt/gate_loss.sh b/scripts/cpt/gate_loss.sh index 941d113..692c894 100644 --- a/scripts/cpt/gate_loss.sh +++ b/scripts/cpt/gate_loss.sh @@ -75,7 +75,7 @@ export GATE_LOSS_RESULTS_DIR="results/RandomSplit-l2_norm-llama_7B-16Select4-up_ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_fpt.py \ + 
smoe/entrypoint/cpt/cpt_fpt.py \ --ignore_data_skip \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ diff --git a/scripts/cpt/lora.sh b/scripts/cpt/lora.sh index 56994c8..062d904 100644 --- a/scripts/cpt/lora.sh +++ b/scripts/cpt/lora.sh @@ -67,7 +67,7 @@ srun torchrun \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node:29518 \ - smoe/entrypoint/cpt_lora.py \ + smoe/entrypoint/cpt/cpt_lora.py \ --deepspeed ${deepspeed_config_file} \ --model_name_or_path ${pretrained_model} \ --model_type ${model_type} \ diff --git a/scripts/tokenize/clustering.sh b/scripts/tokenize/clustering.sh index 6176a83..dbdfae8 100644 --- a/scripts/tokenize/clustering.sh +++ b/scripts/tokenize/clustering.sh @@ -3,8 +3,8 @@ set -vx tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B -data_dir=/mnt/petrelfs/zhutong/smoe/resources/clustering_samples -out_dir=/mnt/petrelfs/share_data/quxiaoye/data/16clusters +data_dir=/mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32 +out_dir=/mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32_tokenized logs_dir=logs mkdir -p $out_dir @@ -13,13 +13,13 @@ mkdir -p $logs_dir # for loop in: en_arxiv, en_book, en_c4, en_cc, en_stack, en_wikipedia, github for data_type in $(ls $data_dir) do - log_path=logs/tokenize_$data_type.log - nohup srun -p MoE -N1 -n1 --cpus-per-task=32 \ + log_path=logs/tokenize_${data_type}_32clusters.log + nohup srun -p MoE -N1 -n1 --cpus-per-task=32 -x "SH-IDCA1404-10-140-54-[12,18,33,38,41,43,63,70-71,74,83,85]" \ python -m smoe.utils.tokenize \ -f jsonl \ -t $tokenizer_dir \ -i $data_dir/$data_type \ -o $out_dir/$data_type \ - 1>$logs_dir/tokenize_$data_type.log 2>&1 & - echo "$data_type > $logs_dir/tokenize_$data_type.log" + 1>${log_path} 2>&1 & + echo "$data_type > $log_path" done diff --git a/scripts/tokenize/lines.sh b/scripts/tokenize/lines.sh new file mode 100644 index 0000000..c876cbc --- /dev/null +++ b/scripts/tokenize/lines.sh @@ -0,0 +1,7 @@ +# srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_8/3.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/8clusters/3.jsonl +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/5.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/5.jsonl 1>logs/tokenize_32_5.log 2>&1 & +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/7.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/7.jsonl 1>logs/tokenize_32_7.log 2>&1 & +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/8.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/8.jsonl 1>logs/tokenize_32_8.log 2>&1 & +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/12.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/12.jsonl 1>logs/tokenize_32_12.log 2>&1 & +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t 
/mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/26.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/26.jsonl 1>logs/tokenize_32_26.log 2>&1 & +nohup srun -p MoE -N1 -n1 --cpus-per-task=8 python -m smoe.utils.tokenize -f jsonl -t /mnt/petrelfs/share_data/quxiaoye/models/llama_7B -i /mnt/petrelfs/zhutong/smoe/resources/clustering_samples_32/31.jsonl -o /mnt/petrelfs/share_data/quxiaoye/data/32clusters/31.jsonl 1>logs/tokenize_32_31.log 2>&1 & diff --git a/scripts/tokenize/redpajama.sh b/scripts/tokenize/redpajama.sh index d453585..1c2400b 100644 --- a/scripts/tokenize/redpajama.sh +++ b/scripts/tokenize/redpajama.sh @@ -2,9 +2,14 @@ set -vx -tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B -data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data -out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed +# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B +# data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data +# out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed + +tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_3B +data_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples +out_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized + logs_dir=logs mkdir -p $logs_dir diff --git a/smoe/callbacks/tensorboard.py b/smoe/callbacks/tensorboard.py index 5c35134..6696024 100644 --- a/smoe/callbacks/tensorboard.py +++ b/smoe/callbacks/tensorboard.py @@ -30,11 +30,18 @@ def on_log( tokens = state.global_step * args.num_tokens_per_batch token_loss_key = "train/loss_on_tokens" self.tb_writer.add_scalar(token_loss_key, v, tokens) - elif ( - k == "train/num_dropped_tokens" - and isinstance(v, tuple) - and all(isinstance(n, (int, float)) for n in v) - ): + elif k == "train/balance_loss": + if isinstance(v, torch.Tensor) and hasattr(v, "item"): + _v = v.item() + elif isinstance(v, float): + _v = v + else: + continue + self.tb_writer.add_scalar(k, _v, state.global_step) + elif k == "train/num_dropped_tokens" and isinstance(v, tuple): + # (tensor(1.0), tensor(2.3)) -> [1.0, 2.3] + if all(isinstance(n, torch.Tensor) for n in v): + v = [n.item() for n in v] self.tb_writer.add_scalars( f"{k}/layer", {str(i): n for i, n in enumerate(v)}, @@ -42,25 +49,17 @@ def on_log( ) self.tb_writer.add_scalar(f"{k}/total", sum(v), state.global_step) elif ( - k == "train/gate_load" - or k == "train/gate_importance" - and isinstance(v, tuple) - and all(isinstance(n, list) for n in v) - ): - _v = [torch.tensor(n) for n in v] + k == "train/gate_load" or k == "train/gate_importance" + ) and isinstance(v, tuple): + if not all(isinstance(n, torch.Tensor) for n in v): + v = [torch.tensor(n) for n in v] + # v: (tensor([1.0, 2.3, ... num_experts]), tensor([3.0, 4.5, ... num_experts]), ... num_layers) self.tb_writer.add_scalars( f"{k}/std/layer", - {str(i): n.std().item() for i, n in enumerate(_v)}, + {str(i): n.std().item() for i, n in enumerate(v)}, state.global_step, ) self.tb_writer.add_image( - k, get_heatmap_img_grid_for_tb(_v), state.global_step - ) - else: - logger.warning( - "Trainer is attempting to log a value of " - f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' - "This invocation of Tensorboard's writer.add_scalar() " - "is incorrect so we dropped this attribute." 
+ k, get_heatmap_img_grid_for_tb(v), state.global_step ) self.tb_writer.flush() diff --git a/smoe/data/aggregation.py b/smoe/data/aggregation.py index ba7e49d..390dcad 100644 --- a/smoe/data/aggregation.py +++ b/smoe/data/aggregation.py @@ -1,7 +1,7 @@ from itertools import chain -def group_texts(examples, block_size: int = 1024): +def group_texts(examples: dict, block_size: int = 1024): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} @@ -17,3 +17,59 @@ def group_texts(examples, block_size: int = 1024): } result["labels"] = result["input_ids"].copy() return result + + +def group_instances(examples: list[dict], block_size: int = 2048) -> list[dict]: + """ + Concate examples to a length of block size. + + Args: + examples: a list of dict instances that have multiple keys + block_size: the length of the concatenated examples + """ + + def _concat(examples: list[dict]) -> dict: + """ + Concatenate the values of each key in the examples. + + Args: + examples: a list of dict instances that have multiple keys + """ + concatenated_examples = {} + keys = examples[0].keys() + for k in keys: + concatenated_examples[k] = list(chain(*[e[k] for e in examples])) + if "labels" not in keys and "input_ids" in keys: + concatenated_examples["labels"] = concatenated_examples["input_ids"] + return concatenated_examples + + def _chunk(examples: dict, block_size: int) -> list[dict]: + """ + Split the concatenated examples into chunks of block_size. + + Args: + examples: a dict instance that has multiple keys + block_size: the length of the concatenated examples + """ + total_length = len(examples[list(examples.keys())[0]]) + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in examples.items() + } + return result + + def _decompose(example: dict) -> list[dict]: + """ + Decompose the example into a list of dict instances. 
+
+        Args:
+            example: a dict instance that has multiple keys
+        """
+        num_chunks = len(example[list(example.keys())[0]])
+        return [{k: example[k][i] for k in example.keys()} for i in range(num_chunks)]
+
+    concatenated_examples = _concat(examples)
+    chunk = _chunk(concatenated_examples, block_size)
+    return _decompose(chunk)
diff --git a/smoe/data/streaming.py b/smoe/data/streaming.py
new file mode 100644
index 0000000..4a018c6
--- /dev/null
+++ b/smoe/data/streaming.py
@@ -0,0 +1,323 @@
+"""
+References:
+    - https://github.com/jzhang38/TinyLlama/blob/main/lit_gpt/packed_dataset.py
+    - https://github.com/jzhang38/TinyLlama/blob/main/pretrain/tinyllama.py
+    - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
+"""
+
+import random
+from pathlib import Path
+from typing import Iterator
+
+import torch
+from torch.utils.data import IterableDataset
+
+from smoe.data.aggregation import group_instances
+from smoe.utils.io import load_jsonlines_iter
+from smoe.utils.logging import get_logger
+from smoe.utils.random import get_random_string
+from smoe.utils.vars import JSONL_DATASET_CACHE_NAME
+
+logger = get_logger(__file__)
+
+
+class JsonlDataset(IterableDataset):
+    def __init__(
+        self,
+        filepath: str,
+        cache_dir: str = None,
+        uid: str = None,
+        seed: int = 1227,
+        buffer_size: int = 32,
+        num_skip: int = None,
+        file_start_byte: int = None,
+    ) -> None:
+        super().__init__()
+
+        # assign `filepath` before deriving the default uid from it
+        self.filepath = filepath
+        if uid:
+            self.uid = uid
+        else:
+            self.uid = f"{Path(self.filepath).stem}-{get_random_string()}"
+        self.cache_dir = cache_dir
+
+        self.seed = seed
+        self.rng = random.Random(seed)
+        self.buffer_size = buffer_size
+        self.num_skip = num_skip
+        self.file_start_byte = file_start_byte
+        self.num_yield = 0
+
+        if self.file_start_byte and self.num_skip:
+            raise ValueError("Cannot set both `file_start_byte` and `num_skip`")
+
+        self.load_fh = load_jsonlines_iter(
+            self.filepath, start_from=self.file_start_byte
+        )
+        self.buffer = []
+
+    def state_dict(self):
+        return {
+            "filepath": self.filepath,
+            "cache_dir": self.cache_dir,
+            "uid": self.uid,
+            "seed": self.seed,
+            "rng": self.rng.getstate(),
+            "num_skip": self.num_skip,
+            "file_start_byte": self.file_start_byte,
+            "buffer_size": self.buffer_size,
+            "num_yield": self.num_yield,
+            "load_fh_tell": self.load_fh.tell(),
+            "buffer": self.buffer,
+        }
+
+    def save_pretrained(self, output_dir: str):
+        state_dict = self.state_dict()
+        name = JSONL_DATASET_CACHE_NAME.format(self.uid)
+        dump_path = Path(output_dir) / name
+        torch.save(state_dict, dump_path)
+        return str(dump_path)
+
+    @classmethod
+    def from_state_dict(cls, state_dict: dict):
+        obj = cls(
+            state_dict["filepath"],
+            state_dict["cache_dir"],
+            uid=state_dict["uid"],
+            seed=state_dict["seed"],
+            buffer_size=state_dict["buffer_size"],
+            num_skip=state_dict["num_skip"],
+            file_start_byte=state_dict["file_start_byte"],
+        )
+        obj.rng.setstate(state_dict["rng"])
+        obj.num_yield = state_dict["num_yield"]
+        obj.buffer = state_dict["buffer"]
+        return obj
+
+    @classmethod
+    def from_pretrained(cls, state_dict_filepath: str):
+        state_dict = torch.load(state_dict_filepath)
+        return cls.from_state_dict(state_dict)
+
+    def __iter__(self) -> Iterator:
+        self.buffer = []
+        for ins in self.load_fh:
+            if self.num_skip and self.num_yield < self.num_skip:
+                self.num_yield += 1
+                continue
+
+            if self.buffer_size <= 1:
+                yield ins
+                continue
+
+            if len(self.buffer) >= self.buffer_size:
+                if len(self.buffer) > 0:
+                    self.rng.shuffle(self.buffer)
+                    yield from self.buffer
+                    self.num_yield += len(self.buffer)
+                    self.buffer.clear()
+
+            self.buffer.append(ins)
+
+        # for the last batch < buffer_size
+        if len(self.buffer) > 0:
+            self.rng.shuffle(self.buffer)
+            yield from self.buffer
+            self.num_yield += len(self.buffer)
+            self.buffer.clear()
+
+
+class WeightedPackedDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        weights: list[float] = None,
+        seed: int = 1227,
+    ):
+        self.datasets = datasets
+        self.weights = weights
+        if weights:
+            assert len(datasets) == len(weights)
+        self.rng = random.Random(seed)
+
+    def __iter__(self):
+        while len(self.datasets) > 0:
+            # sample a dataset index according to the (remaining) weights;
+            # NOTE: each entry in `datasets` is consumed via next(), so it must support it
+            candidate_ids = list(range(len(self.datasets)))
+            choice = self.rng.choices(candidate_ids, weights=self.weights, k=1)[0]
+            try:
+                yield next(self.datasets[choice])
+            except StopIteration:
+                self.datasets.pop(choice)
+                if self.weights:
+                    self.weights.pop(choice)
+                yield from self
+
+
+class WeightedPackedDatasetBuilder:
+    def __init__(
+        self,
+        filepaths: list[str],
+        cache_dir: str,
+        resume: bool = False,
+        seed: int = 1227,
+        buffer_size: int = 32,
+    ) -> None:
+        self.rng = random.Random(seed)
+
+        self.filepaths = filepaths
+        self.rng.shuffle(self.filepaths)
+        self.datasets = []
+
+        resumed_path_to_state_dict = {}
+        if resume:
+            for path in Path(cache_dir).glob(JSONL_DATASET_CACHE_NAME.format("*")):
+                state_dict = torch.load(path)
+                resumed_path_to_state_dict[state_dict["filepath"]] = state_dict
+
+        for filepath in self.filepaths:
+            if filepath in resumed_path_to_state_dict:
+                state_dict = resumed_path_to_state_dict[filepath]
+                self.datasets.append(JsonlDataset.from_state_dict(state_dict))
+            else:
+                self.datasets.append(
+                    JsonlDataset(
+                        filepath,
+                        cache_dir,
+                        seed=seed,
+                        buffer_size=buffer_size,
+                    )
+                )
+
+    def __iter__(self) -> Iterator:
+        for ds in self.datasets:
+            yield from ds
+
+
+class PackedJsonlDataset(IterableDataset):
+    def __init__(
+        self,
+        data_dir: str,
+        seed: int = 1227,
+        buffer_size: int = 200,
+        block_size: int = 2048,
+    ) -> None:
+        super().__init__()
+        self.rng = random.Random(seed)
+        self.buffer_size = buffer_size
+        self.block_size = block_size
+
+        data_dir_path = Path(data_dir)
+        filepaths = sorted(data_dir_path.glob("**/*.jsonl"))
+        self.rng.shuffle(filepaths)
+
+        self.filepaths = filepaths
+        self.buffer = []
+
+    def __iter__(self) -> Iterator:
+        self.buffer = []
+        for filepath in self.filepaths:
+            logger.debug(f"Iter over jsonl file: {filepath}")
+            for ins in load_jsonlines_iter(filepath):
+                if self.buffer_size <= 1:
+                    yield ins
+                    continue
+
+                if len(self.buffer) >= self.buffer_size:
+                    if len(self.buffer) > 0:
+                        self.rng.shuffle(self.buffer)
+                        self.buffer_aggregation()
+                        yield from self.buffer
+                        self.buffer.clear()
+
+                self.buffer.append(ins)
+
+        # for the last batch < buffer_size
+        if len(self.buffer) > 0:
+            self.rng.shuffle(self.buffer)
+            self.buffer_aggregation()
+            yield from self.buffer
+            self.buffer.clear()
+
+    def buffer_aggregation(self):
+        if self.block_size > 0 and len(self.buffer) > 0:
+            results = group_instances(self.buffer, self.block_size)
+            self.buffer = results
+
+
+class SubDirWeightedPackedJsonlDataset(IterableDataset):
+    """
+    Example:
+        >>> dataset = SubDirWeightedPackedJsonlDataset(
+        ...     "/mnt/petrelfs/share_data/redpajama/tokenized",
+        ...     weights={
+        ...         "en_cc": 0.67,
+        ...         "en_c4": 0.15,
+        ...         "github": 0.045,
+        ...         "en_wikipedia": 0.045,
+        ...         "en_book": 0.045,
+        ...         "en_arxiv": 0.025,
+        ...         "en_stack": 0.02,
+        ...     }
+        ... )
+        >>> for ins in dataset:
+        ...
print(ins) + + Inputs: + dataset_dir: folder structure is: + task1 dir: 1.jsonl, 2.jsonl, ... + task2 dir: 1.jsonl, ... + weights: dirname to sampling weight. + e.g. {"task1 dir": 0.3, "task2 dir": 0.7} + """ + + def __init__( + self, + dataset_dir: str, + prob_map: dict[str, float] = None, + seed: int = 1227, + buffer_size: int = 200, + block_size: int = 2048, + ) -> None: + self.rng = random.Random(seed) + self.buffer_size = buffer_size + self.dataset_dir_path = Path(dataset_dir) + + task_types = [p.stem for p in self.dataset_dir_path.glob("*") if p.is_dir()] + + if prob_map is None: + prob_map = {str(task_type): 1.0 for task_type in task_types} + for task_type in task_types: + assert task_type in prob_map + for task_type in prob_map: + if task_type not in task_types: + logger.warning( + f"Task type {task_type} not found in dataset dir. Skip it." + ) + self.prob_map = prob_map + + self.task_type_to_dataset = {} + for task_type in task_types: + # zhutong: use iter to support next() calling, since the dataset itself + # does not implement __next__(). + ds = iter( + PackedJsonlDataset( + str(self.dataset_dir_path.joinpath(task_type)), + seed=seed, + buffer_size=buffer_size, + block_size=block_size, + ) + ) + self.task_type_to_dataset[task_type] = ds + + def __iter__(self) -> Iterator: + while len(self.task_type_to_dataset) > 0: + candidate_task_types = list(self.task_type_to_dataset.keys()) + weights = [self.prob_map[task_type] for task_type in candidate_task_types] + choice = self.rng.choices(candidate_task_types, weights=weights, k=1)[0] + try: + yield next(self.task_type_to_dataset[choice]) + except StopIteration: + # self.task_type_to_dataset.pop(choice) + # logger.debug(f"Task type {choice} finished, drop it") + # yield from self + return diff --git a/smoe/entrypoint/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py similarity index 88% rename from smoe/entrypoint/cpt_fpt.py rename to smoe/entrypoint/cpt/cpt_fpt.py index abb02d2..abebc05 100644 --- a/smoe/entrypoint/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -1,6 +1,7 @@ import os import torch +from torch.distributed.elastic.multiprocessing.errors import record from transformers import ( CONFIG_MAPPING, AutoConfig, @@ -18,6 +19,7 @@ from smoe.callbacks.tensorboard import EnhancedTensorboardCallback from smoe.data.collate_fn import fault_tolerance_data_collator from smoe.data.redpajama import load_streaming_datasets +from smoe.data.streaming import SubDirWeightedPackedJsonlDataset from smoe.metrics.accuracy import compute_metrics from smoe.metrics.preprocess import logits_argmax from smoe.models.llama_moefication.configuration_llama_moe import LlamaMoEConfig @@ -66,9 +68,12 @@ def main(): logger.info(f"Training args: {training_args.to_json_string()}") if training_args.debug_mode: + import torch.distributed as dist + from smoe.utils.debugging import remote_breakpoint - remote_breakpoint() + if dist.get_rank() == 0: + remote_breakpoint() # Detecting last checkpoint. 
last_checkpoint = None @@ -127,13 +132,14 @@ def main(): # zhutong: this is for debug usage only if training_args.debug_mode: - config.num_hidden_layers = 2 + config.num_hidden_layers = 1 tokenizer_kwargs = { "cache_dir": model_args.cache_dir, "use_fast": model_args.use_fast_tokenizer, "revision": model_args.model_revision, "use_auth_token": True if model_args.use_auth_token else None, + "legacy": True if model_args.use_legacy_tokenizer else False, } if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( @@ -171,6 +177,13 @@ def main(): block_size = min(data_args.block_size, tokenizer.model_max_length) if data_args.prob_map is None: + # slimpajama samples openllama-3B tokenized + # data_args.prob_map = { + # "cc": 0.67, + # "wikipedia": 0.33, + # } + + # redpajama data_args.prob_map = { "en_cc": 0.67, "en_c4": 0.15, @@ -191,22 +204,34 @@ def main(): # } with training_args.main_process_first(desc="dataset map tokenization and grouping"): - lm_datasets = load_streaming_datasets( + lm_datasets = SubDirWeightedPackedJsonlDataset( data_args.dataset_dir, prob_map=data_args.prob_map, - num_proc=data_args.preprocessing_num_workers, - debug_mode=training_args.debug_mode, + seed=training_args.seed, block_size=data_args.block_size, ) + # lm_datasets = load_streaming_datasets( + # data_args.dataset_dir, + # prob_map=data_args.prob_map, + # num_proc=data_args.preprocessing_num_workers, + # debug_mode=training_args.debug_mode, + # block_size=data_args.block_size, + # ) if training_args.do_train: train_dataset = lm_datasets if data_args.max_train_samples is None: raise ValueError("max_train_samples cannot be None") logger.info("training example:") - logger.info( - tokenizer.decode([x["input_ids"] for x in train_dataset.take(1)][0]) - ) + res = None + if hasattr(train_dataset, "take"): + res = tokenizer.decode([x["input_ids"] for x in train_dataset.take(1)][0]) + else: + for x in train_dataset: + input_ids = x["input_ids"] + break + res = tokenizer.decode(input_ids) + logger.info(res) eval_dataset = None if training_args.do_eval: @@ -219,6 +244,11 @@ def main(): else getattr(torch, model_args.torch_dtype) ) ModelClass = MODEL_MAP[model_args.model_type] + + # model = LlamaForCausalLM(config) + # model.half() + # model.to(torch_dtype) + model: LlamaForCausalLM | LlamaMoEForCausalLM = ModelClass.from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -229,6 +259,7 @@ def main(): torch_dtype=torch_dtype, low_cpu_mem_usage=True, ) + # train an MoE model from scratch 👇 # config.num_hidden_layers = 20 # model: LlamaMoEForCausalLM = LlamaMoEForCausalLM(config) diff --git a/smoe/entrypoint/cpt_lora.py b/smoe/entrypoint/cpt/cpt_lora.py similarity index 100% rename from smoe/entrypoint/cpt_lora.py rename to smoe/entrypoint/cpt/cpt_lora.py diff --git a/smoe/entrypoint/text_clustering.py b/smoe/entrypoint/text_clustering.py index b9bdead..7be5ef0 100644 --- a/smoe/entrypoint/text_clustering.py +++ b/smoe/entrypoint/text_clustering.py @@ -1,17 +1,21 @@ """ srun -p MoE -n 1 -N 1 --mem 128G python -m smoe.entrypoint.text_clustering --do_train --do_eval -n 16 -m outputs/clustering -o resources/clustering_samples +srun -p MoE -n 1 -N 1 --mem 128G python -u -m smoe.entrypoint.text_clustering --do_eval -n 4 -m outputs/clustering_4 -o resources/clustering_samples_4 1>logs/clustering_4.log 2>&1 & """ import argparse import json -import logging +import os from collections import defaultdict from pathlib import Path +from tqdm import tqdm + from smoe.utils.io 
import load_jsonlines_iter
+from smoe.utils.logging import get_logger
 from smoe.utils.text_clustering import TextClustering
 
-logger = logging.getLogger(__name__)
+logger = get_logger("text_clustering", log_level="INFO")
 
 
 def main(args):
@@ -36,14 +40,43 @@ def main(args):
         model = TextClustering.from_pretrained(args.model_dir)
 
         logger.info("Loading contents")
-        instances = []
+        num_tot = 0
         for file in files:
+            for ins in load_jsonlines_iter(file):
+                num_tot += 1
+
+        instances = []
+        labels = []
+        bsz = 32
+        batch = []
+        bar = tqdm(total=num_tot)
+        for file_idx, file in enumerate(files):
+            logger.info(f"file: {file_idx} / {len(files)}")
             for i, ins in enumerate(load_jsonlines_iter(file)):
-                instances.append(
-                    {"content": ins["content"], "id": i, "file": file.name}
-                )
-        logger.info("Predicting")
-        labels = model.predict([ins["content"] for ins in instances])
+                batch.append(
+                    {"content": ins["content"], "id": i, "file": file.name}
+                )
+                if len(batch) == bsz:
+                    # use a separate variable so the loop's `ins` is not shadowed
+                    preds = model.predict([b["content"] for b in batch])
+                    instances.extend(batch)
+                    labels.extend(preds)
+                    bar.update(len(preds))
+                    batch.clear()
+                # instances.append(
+                #     {"content": ins["content"], "id": i, "file": file.name}
+                # )
+                # label = model.predict([ins["content"]])[0]
+                # labels.append(label)
+        if len(batch) > 0:
+            preds = model.predict([b["content"] for b in batch])
+            instances.extend(batch)
+            labels.extend(preds)
+            bar.update(len(preds))
+            batch.clear()
+
+        logger.info("Predicting finished")
+        # labels = model.predict([ins["content"] for ins in instances])
 
         logger.info("Dumping results")
         out_dir = Path(args.output_dir)
@@ -63,6 +96,9 @@
 
 
 if __name__ == "__main__":
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    os.environ["LOGLEVEL"] = "INFO"
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--do_train", action="store_true")
     parser.add_argument("--do_eval", action="store_true")
diff --git a/smoe/models/llama_moefication/modeling_llama_moe.py b/smoe/models/llama_moefication/modeling_llama_moe.py
index 1146525..2c0749e 100644
--- a/smoe/models/llama_moefication/modeling_llama_moe.py
+++ b/smoe/models/llama_moefication/modeling_llama_moe.py
@@ -29,13 +29,14 @@
 
 @dataclass
 class MoEDecoderLayerOutput(ModelOutput):
+    # zhutong: do not change the order of these fields!!
hidden_states: Optional[torch.FloatTensor] = None + balance_loss: Optional[float] = None + num_dropped_tokens: Optional[torch.Tensor] = None + gate_load: Optional[list[torch.Tensor]] = None + gate_importance: Optional[list[torch.Tensor]] = None self_attn_weights: Optional[torch.FloatTensor] = None present_key_value: Optional[torch.FloatTensor] = None - balance_loss: Optional[float] = None - num_dropped_tokens: Optional[int] = None - gate_load: Optional[list] = None - gate_importance: Optional[list] = None @dataclass @@ -50,7 +51,7 @@ class BaseMoEModelOutputWithPast(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None balance_loss: Optional[float] = None - num_dropped_tokens: Optional[Tuple[int]] = None + num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None gate_load: Optional[Tuple[list]] = None gate_importance: Optional[Tuple[list]] = None @@ -134,15 +135,23 @@ def forward( mlp_outs: MoEMlpOutput = self.mlp(hidden_states) hidden_states = residual + mlp_outs.hidden_states - outputs = MoEDecoderLayerOutput( - hidden_states=hidden_states, - balance_loss=mlp_outs.balance_loss, - self_attn_weights=self_attn_weights if output_attentions else None, - present_key_value=present_key_value if use_cache else None, - num_dropped_tokens=mlp_outs.num_dropped_tokens, - gate_load=mlp_outs.gate_load, - gate_importance=mlp_outs.gate_importance, + outputs = ( + hidden_states, + mlp_outs.balance_loss, + mlp_outs.num_dropped_tokens, + mlp_outs.gate_load, + mlp_outs.gate_importance, ) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + for i, _o in enumerate(outputs): + if not isinstance(_o, torch.Tensor): + raise RuntimeError( + f"outputs[{i}]({type(_o)}) should be torch.Tensor to support grad ckpt" + ) return outputs @@ -313,17 +322,15 @@ def custom_forward(*inputs): return custom_forward - layer_outputs: MoEDecoderLayerOutput = ( - torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) + layer_outputs: tuple = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, ) else: - layer_outputs: MoEDecoderLayerOutput = decoder_layer( + layer_outputs: tuple = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -331,6 +338,7 @@ def custom_forward(*inputs): output_attentions=output_attentions, use_cache=use_cache, ) + layer_outputs = MoEDecoderLayerOutput(*layer_outputs) hidden_states = layer_outputs.hidden_states if layer_outputs.balance_loss is not None: @@ -432,7 +440,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - **kwargs + **kwargs, ): output_attentions = ( output_attentions @@ -490,7 +498,7 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, num_dropped_tokens=outputs.num_dropped_tokens, - balance_loss=outputs.balance_loss.item(), + balance_loss=outputs.balance_loss, gate_load=outputs.gate_load, gate_importance=outputs.gate_importance, ) diff --git a/smoe/modules/moe/moe_calculators.py b/smoe/modules/moe/moe_calculators.py index 0f5abdd..02c0752 100644 --- a/smoe/modules/moe/moe_calculators.py +++ b/smoe/modules/moe/moe_calculators.py @@ -32,7 +32,7 @@ def forward( # fmt: off """正向传播""" """临时变量""" - batch_size = topK_indices.size(0) + batch_size = topK_indices.size(0) # topK_indices: (bsz*seq_len, 
num_selects) num_selects = topK_indices.size(1) topK_indices = topK_indices.flatten() # shape(batch_size*num_selects) topK_scores = topK_scores.flatten() # shape(batch_size*num_selects) @@ -58,6 +58,7 @@ def forward( expert_outputs = [self.experts(split_x[i], i) for i in range(self.num_experts) if split_x[i].shape[0] > 0] """重组各个专家的输出,并进行加权""" + # (bsz*seq_len*num_selects, hidden_size) cat_expert_outputs = torch.cat(expert_outputs, 0) # 拼接专家输出 output_dim = cat_expert_outputs.size(1) if self.multiply_gate_scores: @@ -65,7 +66,7 @@ def forward( zeros = torch.zeros((batch_size, output_dim), device=cat_expert_outputs.device, dtype=cat_expert_outputs.dtype) y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs) # 按照对应的batch编号,添加输出 - return CalculatorOutput(hidden_states=y) + return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0)) # fmt: on @@ -144,4 +145,6 @@ def forward(self, x, topK_indices, topK_scores, **kwargs) -> CalculatorOutput: # 乘权重 y = torch.mul(y, topK_scores.reshape(-1, 1)) - return CalculatorOutput(hidden_states=y, num_dropped_tokens=num_dropped_tokens) + return CalculatorOutput( + hidden_states=y, num_dropped_tokens=torch.tensor(num_dropped_tokens) + ) diff --git a/smoe/modules/moe/moe_gates.py b/smoe/modules/moe/moe_gates.py index 61100a8..92aeb5f 100644 --- a/smoe/modules/moe/moe_gates.py +++ b/smoe/modules/moe/moe_gates.py @@ -98,6 +98,7 @@ def forward(self, x): noise_mm = self.weight_noise(x) # 噪声矩阵计算结果 noise_control = self.softplus(noise_mm) + self.noise_epsilon # 控制器得到的噪声增加量 logits_noise = torch.randn_like(logits_gate) * noise_control # noise附加的权重 + # logits_noise = noise_control * gumbel_rsample(logits_gate.shape, device=logits_gate.device).to(logits_gate) logits = logits_gate + logits_noise # 最终权重 else: logits = logits_gate # 最终权重,shape(batch_size, num_experts) @@ -110,6 +111,9 @@ def forward(self, x): """专家平衡选择""" # zhutong: 不要把`self.training`写在里面的if语句中,否则会导致eval模式下gate loss输出值设备不匹配的错误 + load = torch.tensor(-1.0) + importance = torch.tensor(-1.0) + if self.training and self.use_balance: """计算importance""" zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device) @@ -142,8 +146,8 @@ def forward(self, x): "topK_indices": top_k_indices, "topK_scores": top_k_scores, "balance_loss": balance_loss, - "load": load.tolist(), - "importance": importance.tolist(), + "load": load, + "importance": importance, } def forward_return_scores(self, x): @@ -224,7 +228,7 @@ def __init__( gate_network="mlp", use_softmax=True, use_balance=True, - balance_loss_weight=1e-1, + balance_loss_weight=1e-2, add_noise=True, ): super(SwitchBalancedGate, self).__init__() @@ -276,8 +280,8 @@ def forward(self, x): "topK_scores": top1_scores, "expert_batch_size": load.tolist(), "balance_loss": balance_loss, - "load": load_mean.tolist(), - "importance": importance_mean.tolist(), + "load": load_mean, + "importance": importance_mean, } def reset_gate_network(self): diff --git a/smoe/modules/moe/moe_layers.py b/smoe/modules/moe/moe_layers.py index 064a54e..baa2b24 100644 --- a/smoe/modules/moe/moe_layers.py +++ b/smoe/modules/moe/moe_layers.py @@ -47,8 +47,8 @@ def forward(self, x) -> MoEMlpOutput: hidden_states=y, balance_loss=gate_outputs.get("balance_loss"), num_dropped_tokens=calc_outs.num_dropped_tokens, - gate_load=gate_outputs.get("load"), - gate_importance=gate_outputs.get("importance"), + gate_load=gate_outputs.get("load", torch.tensor(-1)), + gate_importance=gate_outputs.get("importance", torch.tensor(-1)), ) def set_num_selects(self, num_selects): 
diff --git a/smoe/trainer/llama_lr_scheduling.py b/smoe/trainer/llama_lr_scheduling.py index 5eea551..ff95812 100644 --- a/smoe/trainer/llama_lr_scheduling.py +++ b/smoe/trainer/llama_lr_scheduling.py @@ -76,7 +76,6 @@ def _get_cosine_schedule_with_warmup_lr_lambda( num_warmup_steps: int, num_training_steps: int, num_cycles: float, - learning_rate: float, final_lr_portion: float, ): if current_step < num_warmup_steps: @@ -85,7 +84,7 @@ def _get_cosine_schedule_with_warmup_lr_lambda( max(1, num_training_steps - num_warmup_steps) ) return max( - learning_rate * final_lr_portion, + final_lr_portion, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), ) @@ -110,7 +109,6 @@ def create_scheduler( num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=0.5, - learning_rate=self.args.learning_rate, final_lr_portion=self.args.final_lr_portion, ) last_epoch = -1 @@ -147,6 +145,7 @@ def training_step( return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): + # zhutong: return outputs loss, outputs = self.compute_loss(model, inputs, return_outputs=True) if self.args.n_gpu > 1: @@ -160,43 +159,9 @@ def training_step( else: self.accelerator.backward(loss) + # zhutong: return outputs return loss.detach() / self.args.gradient_accumulation_steps, outputs - def compute_loss(self, model, inputs, return_outputs=False): - """ - How the loss is computed by Trainer. By default, all models return the loss in the first element. - - Subclass and override for custom behavior. - """ - if self.label_smoother is not None and "labels" in inputs: - labels = inputs.pop("labels") - else: - labels = None - outputs = model(**inputs) - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index] - - if labels is not None: - if ( - unwrap_model(model)._get_name() - in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() - ): - loss = self.label_smoother(outputs, labels, shift_labels=True) - else: - loss = self.label_smoother(outputs, labels) - else: - if isinstance(outputs, dict) and "loss" not in outputs: - raise ValueError( - "The model did not return a loss from the inputs, only the following keys: " - f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." - ) - # We don't use .loss here since the model may return tuples instead of ModelOutput. 
- loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - - return (loss, outputs) if return_outputs else loss - def _maybe_log_save_evaluate( self, tr_loss, @@ -227,10 +192,12 @@ def _maybe_log_save_evaluate( 4, ) logs["learning_rate"] = self._get_learning_rate() - logs["num_dropped_tokens"] = num_dropped_tokens - logs["gate_load"] = gate_load - logs["gate_importance"] = gate_importance - logs["balance_loss"] = balance_loss + logs["num_dropped_tokens"] = [x.item() for x in num_dropped_tokens] + logs["gate_load"] = [x.detach().cpu().tolist() for x in gate_load] + logs["gate_importance"] = [ + x.detach().cpu().tolist() for x in gate_importance + ] + logs["balance_loss"] = balance_loss.item() self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step @@ -411,6 +378,8 @@ def _inner_training_loop( self._created_lr_scheduler = False if self.is_deepspeed_enabled: + # # zhutong: move model to cuda device for fused optim init + # self.model.to(self.accelerator.device) self.optimizer, self.lr_scheduler = deepspeed_init( self, num_training_steps=max_steps ) @@ -742,20 +711,24 @@ def _inner_training_loop( args, self.state, self.control ) + keys = [ + "balance_loss", + "num_dropped_tokens", + "gate_load", + "gate_importance", + ] + _result_dict = {key: None for key in keys} + for key in keys: + if hasattr(model_training_outputs, key): + _result_dict[key] = getattr(model_training_outputs, key) + self._maybe_log_save_evaluate( tr_loss, model, trial, epoch, ignore_keys_for_eval, - balance_loss=getattr(model_training_outputs, "balance_loss"), - num_dropped_tokens=getattr( - model_training_outputs, "num_dropped_tokens" - ), - gate_load=getattr(model_training_outputs, "gate_load"), - gate_importance=getattr( - model_training_outputs, "gate_importance" - ), + **_result_dict, ) else: self.control = self.callback_handler.on_substep_end( diff --git a/smoe/utils/config.py b/smoe/utils/config.py index c2cf0c1..0fd8594 100644 --- a/smoe/utils/config.py +++ b/smoe/utils/config.py @@ -87,6 +87,15 @@ class ModelArguments: ) }, ) + use_legacy_tokenizer: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use the legacy tokenization or not. Only has an effect when" + " using a sentencepiece-based tokenizer. Ref: https://github.com/huggingface/transformers/pull/24565" + ) + }, + ) model_revision: str = field( default="main", metadata={ diff --git a/smoe/utils/debugging.py b/smoe/utils/debugging.py index 69ed9a3..366be43 100644 --- a/smoe/utils/debugging.py +++ b/smoe/utils/debugging.py @@ -1,4 +1,5 @@ import debugpy +import torch.distributed as dist def remote_breakpoint(host: str = "0.0.0.0", port: int = 5678): @@ -40,6 +41,10 @@ def remote_breakpoint(host: str = "0.0.0.0", port: int = 5678): After the program starts and encounters the breakpoint, you could remote attach the debugger. 
""" - debugpy.listen((host, port)) - debugpy.wait_for_client() - breakpoint() + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + if rank == 0: + debugpy.listen((host, port)) + debugpy.wait_for_client() + breakpoint() + dist.barrier() diff --git a/smoe/utils/io.py b/smoe/utils/io.py index 14b3955..43dffb4 100644 --- a/smoe/utils/io.py +++ b/smoe/utils/io.py @@ -34,21 +34,33 @@ def load_compressed_file_gz(path): # gz return data -def load_jsonlines_iter(filepath): - with open(filepath, "rt", encoding="utf8") as fin: - for line in fin: - yield json.loads(line) +class load_jsonlines_iter: + def __init__(self, filepath, start_from: int = None) -> None: + self.fin = open(filepath, "r", encoding="utf8") + if start_from: + self.fin.seek(start_from, os.SEEK_SET) + + def tell(self): + return self.fin.tell() + + def __iter__(self): + for line in self.fin: + try: + yield json.loads(line) + except json.JSONDecodeError: + pass + self.fin.close() def load_jsonlines(filepath): data = [] - with open(filepath, "rt", encoding="utf8") as fin: + with open(filepath, "r", encoding="utf8") as fin: for line in fin: data.append(json.loads(line)) return data def dump_jsonlines(obj, filepath, **kwargs): - with open(filepath, "wt", encoding="utf8") as fout: + with open(filepath, "w", encoding="utf8") as fout: for ins in obj: fout.write(f"{json.dumps(ins, ensure_ascii=False, **kwargs)}\n") diff --git a/smoe/utils/random.py b/smoe/utils/random.py new file mode 100644 index 0000000..c3d6a47 --- /dev/null +++ b/smoe/utils/random.py @@ -0,0 +1,14 @@ +import random +import string + + +def get_random_string(length: int = 8) -> str: + """Generate a unique random string. + + Args: + length (int, optional): Length of the random string. Defaults to 16. + + Returns: + str: A unique random string. 
+ """ + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) diff --git a/smoe/utils/text_clustering.py b/smoe/utils/text_clustering.py index 0d6624e..09f000d 100644 --- a/smoe/utils/text_clustering.py +++ b/smoe/utils/text_clustering.py @@ -20,7 +20,7 @@ def num_clusters(self) -> int: return self.kmeans.n_clusters def encode_emb(self, sentences: list[str]) -> np.ndarray: - arr: np.ndarray = self.emb.encode(sentences=sentences) + arr: np.ndarray = self.emb.encode(sentences=sentences, show_progress_bar=False) return arr def fit_emb(self, emb: np.ndarray): diff --git a/smoe/utils/vars.py b/smoe/utils/vars.py index 6e09ab1..b68c481 100644 --- a/smoe/utils/vars.py +++ b/smoe/utils/vars.py @@ -2,3 +2,4 @@ BEST_MODEL_CKPT_DIR = "best" MIDDLE_MODEL_CKPT_DIR = "middle" CLUSTERING_MODEL_NAME = "clustering.model" +JSONL_DATASET_CACHE_NAME = "jsonl_dataset-{}.bin" diff --git a/smoe/utils/visualization/visualize.py b/smoe/utils/visualization/visualize.py index 0f712be..f0d0511 100644 --- a/smoe/utils/visualization/visualize.py +++ b/smoe/utils/visualization/visualize.py @@ -191,6 +191,9 @@ def visualize_swiglu_output( def find_factors_with_minimal_sum(number): + if number == 1: + return (1, 1) + # Initialize variables to keep track of the factors with the minimal sum min_sum = float("inf") min_factors = None @@ -286,8 +289,8 @@ def vis_tuple_heatmaps(tensors: tuple[torch.FloatTensor]): axes = axes.reshape(*img_grid) for i in range(data.shape[0]): ax = axes[i // img_grid[1], i % img_grid[1]] - im = ax.imshow( - data[i].cpu().reshape(*shape).float().numpy(), + ax.imshow( + data[i].cpu().reshape(*shape).float().detach().numpy(), cmap=cmap, interpolation="nearest", # vmin=0.0, diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/data/test_aggregation.py b/tests/data/test_aggregation.py new file mode 100644 index 0000000..f263772 --- /dev/null +++ b/tests/data/test_aggregation.py @@ -0,0 +1,21 @@ +from smoe.data.aggregation import group_instances + + +def test_group_instances(): + instances = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + ] + results = group_instances(instances, block_size=4) + assert results == [ + {"input_ids": [1, 2, 3, 1], "labels": [4, 5, 6, 4]}, + {"input_ids": [2, 3, 1, 2], "labels": [5, 6, 4, 5]}, + {"input_ids": [3, 1, 2, 3], "labels": [6, 4, 5, 6]}, + ] + + +if __name__ == "__main__": + test_group_instances() diff --git a/tests/data/test_redpajama.py b/tests/data/test_redpajama.py new file mode 100644 index 0000000..c26e34c --- /dev/null +++ b/tests/data/test_redpajama.py @@ -0,0 +1,57 @@ +import time +from collections import defaultdict +from pathlib import Path + +from torch.utils.data import DataLoader + +from smoe.data.redpajama import load_streaming_datasets +from smoe.utils.io import dump_jsonlines, load_jsonlines + + +def test_load_streaming_datasets(): + output_dir = Path("/mnt/petrelfs/zhutong/smoe/resources/data_test_with_task_type") + output_dir.mkdir(parents=True, exist_ok=True) + # dataset_dir = Path("resources/data_test") + dataset_dir = Path("resources/data_test_with_task_type") + + # # update new dataset with task type + # for subtask_dir in dataset_dir.glob("*"): + # task_type = subtask_dir.stem + # subtask_out_dir = output_dir.joinpath(task_type) + # 
subtask_out_dir.mkdir(parents=True, exist_ok=True)
+    #     for file in subtask_dir.glob("*.jsonl"):
+    #         data = load_jsonlines(file)
+    #         for ins in data:
+    #             ins["src"] = task_type
+    #         dump_jsonlines(data, subtask_out_dir.joinpath(file.name))
+
+    dataset = load_streaming_datasets(
+        str(dataset_dir),
+        prob_map={"en_arxiv": 0.5, "en_book": 0.2, "en_c4": 0.3},
+        block_size=2048,
+    )
+    num_ds = 0
+    num_src = defaultdict(lambda: 0)
+
+    start = time.time()
+    for ds in iter(dataset):
+        num_ds += 1
+        # print(num_ds, ds["src"])
+        # num_src[ds["src"]] += 1
+    time_span = time.time() - start
+    print(num_ds)
+    print(dict(num_src))
+    print(f"Time (ins/s): {num_ds / time_span:.2f}")
+
+    """
+    block_size: -1
+    {'en_arxiv': 400, 'en_c4': 214}
+    Time (ins/s): 64.05
+
+    block_size: 2048
+    Time (ins/s): 59.94
+    """
+
+
+if __name__ == "__main__":
+    test_load_streaming_datasets()
diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py
new file mode 100644
index 0000000..cae1582
--- /dev/null
+++ b/tests/data/test_streaming.py
@@ -0,0 +1,105 @@
+import tempfile
+import time
+from collections import defaultdict
+from pathlib import Path
+
+import pytest
+
+from smoe.data.streaming import JsonlDataset, SubDirWeightedPackedJsonlDataset
+from smoe.utils.io import load_jsonlines
+
+
+def test_jsonl_dataset():
+    def _get_num_iter(ds):
+        num_ins = 0
+        for _ in ds:
+            num_ins += 1
+        return num_ins
+
+    filepath = "/mnt/petrelfs/zhutong/smoe/resources/redpajama/en_arxiv/head2k.jsonl"
+    data = load_jsonlines(filepath)
+
+    dataset = JsonlDataset(filepath, buffer_size=16)
+    assert len(data) == _get_num_iter(dataset)
+
+    num_skip = 50
+    dataset = JsonlDataset(filepath, num_skip=num_skip)
+    assert len(data) - num_skip == _get_num_iter(dataset)
+
+    dataset = JsonlDataset(filepath, buffer_size=6)
+    num_ins = 0
+    for _ in dataset:
+        num_ins += 1
+        if num_ins == num_skip:
+            break
+    start_from = dataset.load_fh.tell()
+    temp_dir = tempfile.mkdtemp()
+    path = dataset.save_pretrained(temp_dir)
+
+    # `from_pretrained` expects the state-dict file dumped by `save_pretrained`
+    new_dataset = JsonlDataset.from_pretrained(path)
+
+
+@pytest.mark.skipif(
+    not Path("resources/data_test_with_task_type").exists(),
+    reason="Test data dir not found",
+)
+def test_subdir_weighted_pack_with_type():
+    dataset = SubDirWeightedPackedJsonlDataset(
+        "resources/data_test_with_task_type",
+        prob_map={"en_arxiv": 0.5, "en_book": 0.2, "en_c4": 0.3},
+        buffer_size=1000,
+        block_size=2048,
+    )
+    num_ds = 0
+    num_src = defaultdict(lambda: 0)
+
+    start = time.time()
+    for ds in iter(dataset):
+        num_ds += 1
+        # print(num_ds, ds["src"])
+        # num_src[ds["src"]] += 1
+    time_span = time.time() - start
+    print(num_ds)
+    print(dict(num_src))
+    print(f"Time (ins/s): {num_ds / time_span:.2f}")
+
+    """
+    block_size: -1
+    {'en_arxiv': 400, 'en_c4': 244}
+    Time (ins/s): 1075.88
+    16.797501951600314 times faster than hf-datasets!
+
+    block_size: 2048, buffer_size: 1000
+    Time (ins/s): 283.53
+    4.73023023023023 times faster than hf-datasets!
+ """ + + +def test_weighted_streaming(): + prob_map = { + "en_cc": 0.67, + "en_c4": 0.15, + "github": 0.045, + "en_wikipedia": 0.045, + "en_book": 0.045, + "en_arxiv": 0.025, + "en_stack": 0.02, + } + lm_datasets = SubDirWeightedPackedJsonlDataset( + "/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed", + prob_map=prob_map, + seed=1227, + block_size=2048, + ) + for ds in lm_datasets: + print(ds["input_ids"]) + break + for ds in lm_datasets: + print(ds["input_ids"]) + break + + +if __name__ == "__main__": + # test_jsonl_dataset() + # test_subdir_weighted_pack_with_type() + test_weighted_streaming()