From 1dcb404dd94e18eb01bcdc1781b6db6352a6724a Mon Sep 17 00:00:00 2001
From: gobbleturk
Date: Thu, 16 Jan 2025 00:16:50 +0000
Subject: [PATCH] Use an async checkpointer to initialize JDI in tests

---
 MaxText/layers/linears.py                      |  5 +----
 end_to_end/tpu/gemma/7b/2_test_gemma.sh        |  9 +++++----
 end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh | 12 +++++++-----
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index d1efeca35..1820a26bc 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -432,10 +432,7 @@ def gmm(inputs, kernel, group_sizes):
       )
     else:
       if self.quant is not None:
-        raise NotImplementedError(
-            "Quantization is not yet supported with ragged_dot, please set"
-            " megablox=True"
-        )
+        raise NotImplementedError("Quantization is not yet supported with ragged_dot, please set megablox=True")
       output = jax.lax.ragged_dot(
           lhs=inputs,
           rhs=kernel,
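For context on the ragged_dot path guarded by that error: jax.lax.ragged_dot is a grouped matmul in which consecutive row blocks of the left operand are each multiplied by their own right-hand matrix, which is why an unquantized kernel is required here. A minimal sketch of the call it expects, with toy shapes chosen purely for illustration (assumes a JAX version that ships jax.lax.ragged_dot):

    import jax
    import jax.numpy as jnp

    # lhs rows are split into len(group_sizes) consecutive groups; group i is
    # multiplied by rhs[i], so group_sizes must sum to lhs.shape[0].
    lhs = jnp.ones((6, 4))                            # (m, k)
    rhs = jnp.ones((2, 4, 3))                         # (num_groups, k, n)
    group_sizes = jnp.array([2, 4], dtype=jnp.int32)  # 2 + 4 == m
    out = jax.lax.ragged_dot(lhs=lhs, rhs=rhs, group_sizes=group_sizes)
    print(out.shape)                                  # (6, 3)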
diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
index 549c78d8a..41a8ddd27 100644
--- a/end_to_end/tpu/gemma/7b/2_test_gemma.sh
+++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
@@ -36,16 +36,17 @@ export RUN_NAME=unscanned_chkpt
 # We defined path to unscanned checkpoint created in 1_test_gemma.sh
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
 
+export ASYNC_CHECKPOINTING=true # true so that the JAX distributed system (JDI) is initialized
 # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
 
 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to"
 
 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
-python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b checkpoint_period=5
 
 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
 python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b
@@ -57,7 +58,7 @@ export PARAM_RUN_NAME=param_chkpt
 python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true
 
 # Now, run decoding on the checkpoint generated from our finetune run.
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
 
 # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
 # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B.
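The reason both test scripts now export ASYNC_CHECKPOINTING=true is the one given in the inline comment: MaxText only brings up the JAX distributed system (JDI) when an asynchronous checkpointer is requested, so forcing async checkpointing in these end-to-end tests exercises that initialization path. A rough sketch of the gating logic, assuming a helper along the lines of MaxText's maybe_initialize_jax_distributed_system (the helper name and call site here are illustrative, not a quote of the MaxText source):

    import jax

    def maybe_initialize_jax_distributed_system(async_checkpointing: bool) -> None:
      # The async checkpointer needs a coordination service for cross-host
      # commits and barriers, so the JAX distributed runtime is started only
      # when async checkpointing is enabled; synchronous runs can skip it.
      if async_checkpointing:
        jax.distributed.initialize()  # auto-detects the coordinator on Cloud TPU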
diff --git a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
index 19b271653..8a820b650 100644
--- a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
+++ b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
@@ -40,16 +40,18 @@ export RUN_NAME=unscanned_chkpt
 # We defined path to unscanned checkpoint created in 1_test_llama2_70b.sh
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
 
+export ASYNC_CHECKPOINTING=true # true so that the JAX distributed system (JDI) is initialized
+
 # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
 
 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
 
 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
-python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} checkpoint_period=5
 
 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
 python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}
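For context on the generate_param_only_checkpoint.py step in the hunk just below: it turns the full training-state checkpoint saved at step 5 of the finetune run into a parameter-only checkpoint that decode.py can load via load_parameters_path, which is why the scripts keep noting that these checkpoints "only have parameters and no optimizer state". Conceptually this is "restore the full state, keep the params subtree, save it again". The sketch below illustrates that idea with Orbax's PyTreeCheckpointer, an assumed {"params": ...} tree layout, and hypothetical GCS paths; it is not the MaxText implementation, which also handles sharding and force_unroll:

    import orbax.checkpoint as ocp

    # Illustrative paths; the scripts build the real ones from BASE_OUTPUT_PATH,
    # FINETUNE_RUN_NAME, and PARAM_RUN_NAME.
    full_state_path = "gs://my-bucket/runner_finetune/checkpoints/5/items"
    param_only_path = "gs://my-bucket/param_chkpt/checkpoints/0/items"

    ckptr = ocp.PyTreeCheckpointer()
    full_state = ckptr.restore(full_state_path)  # params, optimizer state, step, ...
    ckptr.save(param_only_path, {"params": full_state["params"]})  # drop optimizer state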
@@ -61,7 +63,7 @@ export PARAM_RUN_NAME=param_chkpt
 python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true
 
 # Now, run decoding on the checkpoint generated from our finetune run.
-python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python MaxText/decode.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
 
 # We also test whether the forward pass logits match the golden logits for Llama2-70b
-python3 MaxText/tests/forward_pass_logit_checker.py --atol=0.2 --rtol=0.2 MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=false
+python3 MaxText/tests/forward_pass_logit_checker.py --atol=0.2 --rtol=0.2 MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=${ASYNC_CHECKPOINTING}
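Finally, on the forward_pass_logit_checker.py invocation above: it compares the model's forward-pass logits against stored golden logits within the given --atol/--rtol. A minimal sketch of that kind of tolerance check (the function and array names are placeholders, not the MaxText test's actual API):

    import numpy as np

    def check_logits(test_logits: np.ndarray, golden_logits: np.ndarray,
                     atol: float = 0.2, rtol: float = 0.2) -> None:
      # Fails with a detailed mismatch report if any element differs by more
      # than atol + rtol * |golden|.
      np.testing.assert_allclose(test_logits, golden_logits, atol=atol, rtol=rtol)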