From 2d0c943357324e8d77ce8b77ff75232e21edaf23 Mon Sep 17 00:00:00 2001
From: Rissy Ran
Date: Mon, 13 Jan 2025 22:27:19 +0000
Subject: [PATCH] Add more 8x22b pre-training and decoding tests

---
 .../tpu/mixtral/8x22b/2_test_mixtral.sh | 30 ++++++++++++++-----
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
index f0ca70cd4..2cc9c8afc 100644
--- a/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
+++ b/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
@@ -20,19 +20,33 @@ if [ -z "${BASE_OUTPUT_PATH}" ]; then
 fi
 
 export DATASET_PATH=gs://maxtext-dataset
-
-# `SCANNED_CHECKPOINT` refers to the checkpoint that used for both `train.py` and `decode.py`
-export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items
-
 export TOKENIZER_PATH=assets/tokenizer.mistral-v3
 
-# TODO(ranran): enable the fine-tuning, decoding, and forward_pass_logit_checker tests once b/380148614 has been fixed
-
 # Run pre-training without load_parameters_path - megablox implementation
 python3 MaxText/train.py MaxText/configs/base.yml \
     base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
-    run_name=pre_training per_device_batch_size=1 enable_checkpointing=false \
+    run_name=pre_training_megablox per_device_batch_size=4 enable_checkpointing=false \
     model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
-    steps=5 max_target_length=1024 async_checkpointing=false \
+    steps=5 max_target_length=128 async_checkpointing=false \
     tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
     weight_dtype=bfloat16 megablox=True
+
+# Run pre-training without load_parameters_path - matmul implementation
+python3 MaxText/train.py MaxText/configs/base.yml \
+    base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
+    run_name=pre_training_matmul per_device_batch_size=4 enable_checkpointing=false \
+    model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
+    steps=5 max_target_length=128 async_checkpointing=false \
+    tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
+    weight_dtype=bfloat16 megablox=False
+
+# Run pre-training without load_parameters_path - dropping implementation
+python3 MaxText/train.py MaxText/configs/base.yml \
+    base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
+    run_name=pre_training_dropping per_device_batch_size=4 enable_checkpointing=false \
+    model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
+    steps=5 max_target_length=128 async_checkpointing=false \
+    tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
+    weight_dtype=bfloat16 megablox=False capacity_factor=1.25
+
+# TODO(ranran): Add decoding, fine-tuning, and forward_pass_logit_checker tests after b/384580048
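
Note on the deferred decoding test: once b/384580048 is resolved, a decoding check for this model would likely follow the same pattern as the pre-training runs above. The sketch below is illustrative only and is not part of the patch; it assumes a parameter checkpoint is available at a hypothetical ${SCANNED_CHECKPOINT} path (the variable removed by this change), and the prompt, prefill/target lengths, and megablox setting are placeholder values to be verified against the config accepted by MaxText/decode.py before they are added to the script.

# Hypothetical decoding test sketch - not part of this patch
# Assumes ${SCANNED_CHECKPOINT} points to a scanned mixtral-8x22b parameter checkpoint
python3 MaxText/decode.py MaxText/configs/base.yml \
    load_parameters_path=${SCANNED_CHECKPOINT} run_name=decoding_megablox \
    per_device_batch_size=1 model_name=mixtral-8x22b \
    ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
    max_prefill_predict_length=11 max_target_length=24 \
    async_checkpointing=false tokenizer_path=${TOKENIZER_PATH} \
    attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=True \
    prompt="[INST] I love to [/INST]"

As with the pre-training runs, the same command could be repeated with megablox=False (and optionally capacity_factor) to cover the matmul and dropping implementations.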