From 90107c26e95607b75f5ad7e45c86349ad256b956 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Mon, 16 Dec 2024 19:48:56 +0000 Subject: [PATCH] Added TFDS config. --- benchmarks/benchmark_runner.py | 2 + benchmarks/maxtext_trillium_model_configs.py | 40 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 291672ddb..bf942bda3 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -89,6 +89,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data', 'llama2_70b_4096_pw_long_run', 'llama2_70b_4096_real_data_pw_long_run', + 'llama2_70b_4096_pw_rd_tfds ', 'llama2_70b_4096_synthetic_pw_lr', 'llama2_70b_4096_synthetic', 'llama3_70b_8192', @@ -108,6 +109,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data ' 'llama2_70b_4096_pw_long_run ' 'llama2_70b_4096_real_data_pw_long_run ' + 'llama2_70b_4096_pw_rd_tfds ' 'llama2_70b_4096_synthetic_pw_lr ' 'llama2_70b_4096_synthetic ' 'llama3_1_405b_8192_fsdp_dcn ' diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 61c8b1883..74f16203c 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -462,6 +462,45 @@ class MaxTextModel: ), ) +llama2_70b_4096_pw_rd_tfds = MaxTextModel( + model_name="llama2_70b_4096_pw_rd_tfds", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + + llama3_8b_8192 = MaxTextModel( model_name="llama3-8b-8192", model_type="llama3-8b", @@ -760,6 +799,7 @@ class MaxTextModel: llama2_70b_4096_pw_long_run, llama2_70b_4096_real_data, llama2_70b_4096_real_data_pw_long_run, + llama2_70b_4096_pw_rd_tfds, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama2_70b_4096_synthetic_pw_lr,