From b1c4730d97db16c9a5ea343b093ce70dc5127c64 Mon Sep 17 00:00:00 2001
From: Rissy Ran
Date: Wed, 27 Nov 2024 19:37:11 +0000
Subject: [PATCH] add more MoE tests

---
 end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
index bc112677b..fa6acdd35 100644
--- a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
+++ b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -31,15 +31,21 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/item
 # TODO(ranran): add decoding test for megablox implementation
 python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" megablox=False scan_layers=false
 
+# Run decoding with converted ckpt - dropping implementation
+python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" megablox=False scan_layers=false capacity_factor=1.25
+
 # Test whether the forward pass logits match the golden logits - matmul implementation
-python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=8 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False scan_layers=false --atol=3 --rtol=1 --token_size=4
+python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=4 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False scan_layers=false --atol=3 --rtol=1 --token_size=4
 
 # Test whether the forward pass logits match the golden logits - megablox implementation
 # TODO(ranran): investigate the root cause of the excessive tolerance
-python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=megablox_forward_pass_test per_device_batch_size=8 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_fsdp_parallelism=64 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 scan_layers=false --atol=20 --rtol=10 --token_size=4
+python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=megablox_forward_pass_test per_device_batch_size=4 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_fsdp_parallelism=64 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 scan_layers=false --atol=20 --rtol=10 --token_size=4
+
+# Run pre-training - megablox implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
 
-# Run fine-tuning - megablox implementation
-python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning per_device_batch_size=8 model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=10 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 checkpoint_period=5 attention=flash dtype=bfloat16 weight_dtype=bfloat16
+# Run pre-training - matmul implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False
 
-# Run pre-training without load_parameters_path - megablox implementation
-python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=pre_training per_device_batch_size=8 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
+# Run pre-training - dropping implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False capacity_factor=1
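
A minimal sketch of verifying this patch locally (not part of the patch itself), assuming a MaxText checkout, a saved patch file named 0001-add-more-MoE-tests.patch (the git format-patch default for this subject, assumed here), and a TPU slice large enough for the ici_*_parallelism settings above (64 devices), with the output/dataset paths the script's earlier lines expect already configured:

    # Apply the commit and run the updated Mixtral 8x7B end-to-end test script.
    git am 0001-add-more-MoE-tests.patch
    bash end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh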