
Commit

Merge pull request AI-Hypercomputer#1164 from AI-Hypercomputer:more_8x22b_tests

PiperOrigin-RevId: 715162042
maxtext authors committed Jan 14, 2025
2 parents 6304b46 + 2d0c943 commit bfc4264
Showing 1 changed file with 22 additions and 8 deletions.
end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh: 22 additions & 8 deletions
@@ -20,19 +20,33 @@ if [ -z "${BASE_OUTPUT_PATH}" ]; then
fi

export DATASET_PATH=gs://maxtext-dataset

# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py`
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items

export TOKENIZER_PATH=assets/tokenizer.mistral-v3

# TODO(ranran): enable the fine-tuning, decoding, and forward_pass_logit_checker tests once b/380148614 has been fixed

# Run pre-training without load_parameters_path - megablox implementation
python3 MaxText/train.py MaxText/configs/base.yml \
base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
-run_name=pre_training per_device_batch_size=1 enable_checkpointing=false \
+run_name=pre_training_megablox per_device_batch_size=4 enable_checkpointing=false \
model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
-steps=5 max_target_length=1024 async_checkpointing=false \
+steps=5 max_target_length=128 async_checkpointing=false \
tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
weight_dtype=bfloat16 megablox=True

# Run pre-training without load_parameters_path - matmul implementation
python3 MaxText/train.py MaxText/configs/base.yml \
base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
run_name=pre_training_matmul per_device_batch_size=4 enable_checkpointing=false \
model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
steps=5 max_target_length=128 async_checkpointing=false \
tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
weight_dtype=bfloat16 megablox=False

# Run pre-training without load_parameters_path - dropping implementation
python3 MaxText/train.py MaxText/configs/base.yml \
base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
run_name=pre_training_dropping per_device_batch_size=4 enable_checkpointing=false \
model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \
steps=5 max_target_length=128 async_checkpointing=false \
tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
weight_dtype=bfloat16 megablox=False capacity_factor=1.25

# TODO(ranran): Add decoding, fine-tuning, and forward_pass_logit_checker tests after b/384580048
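
For reference, a minimal usage sketch (not part of this commit) of how the test script above would typically be invoked from the MaxText repository root, which the relative paths `MaxText/train.py` and `assets/tokenizer.mistral-v3` assume. It also assumes `BASE_OUTPUT_PATH` and `MODEL_VARIATION` are available in the environment (the script checks `BASE_OUTPUT_PATH` at the top of the hunk); the bucket path below is a hypothetical placeholder.

# Usage sketch: the output bucket is a placeholder you own, and MODEL_VARIATION
# is assumed to resolve to 8x22b for this test.
export BASE_OUTPUT_PATH=gs://<your-output-bucket>/mixtral-8x22b-test
export MODEL_VARIATION=8x22b
bash end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh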
