From 4642cfb8ea6f80baa48e5296b0934011b7716624 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 25 Oct 2024 01:53:35 +0000 Subject: [PATCH 01/24] Attempt to add JAX functional tests --- dags/multipod/configs/jax_tests_gce_config.py | 76 +++++++++++++++ dags/multipod/configs/jax_tests_gke_config.py | 0 dags/multipod/jax_functional_tests.py | 92 +++++++++++++++++++ dags/test_owner.py | 3 + deployment/cloud_composer.auto.tfvars | 10 +- .../.terraform.lock.hcl | 21 +++++ deployment/modules/composer_env/main.tf | 2 +- deployment/provider.tf | 2 +- 8 files changed, 197 insertions(+), 9 deletions(-) create mode 100644 dags/multipod/configs/jax_tests_gce_config.py create mode 100644 dags/multipod/configs/jax_tests_gke_config.py create mode 100644 dags/multipod/jax_functional_tests.py create mode 100644 deployment/deployment_composer_env/.terraform.lock.hcl diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py new file mode 100644 index 00000000..159ee69b --- /dev/null +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -0,0 +1,76 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities to construct configs for JAX tests.""" + +from xlml.apis import gcp_config, metric_config, task, test_config +from dags import test_owner, gcs_bucket +from dags.multipod.configs import common +from dags.vm_resource import TpuVersion, Project, RuntimeVersion +import datetime + +PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value +RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value +# PROJECT_NAME = Project.TPU_PROD_ENV_MULTIPOD.value +# RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value + +def get_jax_distributed_initialize_config( + tpu_version: TpuVersion, + tpu_cores: int, + tpu_zone: str, + time_out_in_min: int, + test_name: str, + test_mode: common.SetupMode, + project_name: str = PROJECT_NAME, + runtime_version: str = RUNTIME_IMAGE, + network: str = "default", + subnetwork: str = "default", + is_tpu_reserved: bool = True, + automated_test: bool = True, + num_slices: int = 1, + base_output_dir: str = gcs_bucket.BASE_OUTPUT_DIR, +): + set_up_cmds = 'pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; python -m pip install fabric;' + run_model_cmds = ( + ( + "python -m 'print(10111); import jax; jax.distributed.initialize(); print(20111);'" + ), + ) # TODO: Do we need to verify success or will the command fail if execution fails? + job_test_config = test_config.TpuVmTest( + test_config.Tpu( + version=tpu_version, + cores=tpu_cores, + runtime_version=runtime_version, + reserved=is_tpu_reserved, + network=network, + subnetwork=subnetwork, + ), + test_name=test_name, + set_up_cmds=set_up_cmds, + run_model_cmds=run_model_cmds, + timeout=datetime.timedelta(minutes=time_out_in_min), + task_owner=test_owner.AKANKSHA_G, + num_slices=num_slices, + ) + + job_gcp_config = gcp_config.GCPConfig( + project_name=project_name, + zone=tpu_zone, + dataset_name=metric_config.DatasetOption.XLML_DATASET, # TODO: can remove? + ) + + return task.run_queued_resource_test( + task_test_config=job_test_config, + task_gcp_config=job_gcp_config, + ) diff --git a/dags/multipod/configs/jax_tests_gke_config.py b/dags/multipod/configs/jax_tests_gke_config.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py new file mode 100644 index 00000000..ddd5147f --- /dev/null +++ b/dags/multipod/jax_functional_tests.py @@ -0,0 +1,92 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A DAG to run JAX tests.""" + +import datetime +from airflow import models +from dags import composer_env +from dags.vm_resource import TpuVersion, Zone +from dags.multipod.configs import jax_tests_gce_config +from dags.multipod.configs.common import SetupMode + +# Run once a day at am UTC ( am PST) # TODO time +SCHEDULED_TIME = "0 9 * * *" if composer_env.is_prod_env() else None # TODO time + +with models.DAG( + dag_id="jax_functional_tests", + schedule=SCHEDULED_TIME, + tags=["multipod_team", "jax"], + start_date=datetime.datetime(2024, 10, 23), + catchup=False, +) as dag: + default_test_name = "jax" + test_mode = SetupMode.NIGHTLY + + # v4 + jax_nightly_1slice_v4_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + test_name=default_test_name, + test_mode=test_mode, + ) + + jax_nightly_2slice_v4_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=2, + test_name=default_test_name, + test_mode=test_mode, + ) + + # v5p + # v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value + # v5p_network = V5_NETWORKS + # v5p_subnetwork = V5P_SUBNETWORKS + # v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value + + # jax_nightly_1slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + # tpu_version=TpuVersion.V5P, + # tpu_cores=8, + # tpu_zone=Zone.US_EAST5_A.value, + # runtime_version=v5p_runtime_version, + # project_name=v5p_project_name, + # time_out_in_min=60, + # is_tpu_reserved=True, + # test_name=default_test_name, + # test_mode=test_mode, + # network=v5p_network, + # subnetwork=v5p_subnetwork, + # ) + + # jax_nightly_2slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + # tpu_version=TpuVersion.V5P, + # tpu_cores=8, + # num_slices=2, + # tpu_zone=Zone.US_EAST5_A.value, + # runtime_version=v5p_runtime_version, + # project_name=v5p_project_name, + # time_out_in_min=60, + # is_tpu_reserved=True, + # test_name=default_test_name, + # test_mode=test_mode, + # network=v5p_network, + # subnetwork=v5p_subnetwork, + # ) diff --git a/dags/test_owner.py b/dags/test_owner.py index dcfccc9d..3e9d8d91 100644 --- a/dags/test_owner.py +++ b/dags/test_owner.py @@ -66,3 +66,6 @@ class Team(enum.Enum): # FRAMEWORK QINY_Y = "Qinyi Y." + +# JAX +AKANKSHA_G = "Akanksha G." diff --git a/deployment/cloud_composer.auto.tfvars b/deployment/cloud_composer.auto.tfvars index 0a307c21..c7cfa022 100644 --- a/deployment/cloud_composer.auto.tfvars +++ b/deployment/cloud_composer.auto.tfvars @@ -6,11 +6,7 @@ project_config = { environment_config = [ { - environment_name = "ml-automation-solutions" - service_account_id = "ml-auto-solutions" - }, - { - environment_name = "ml-automation-solutions-dev" - service_account_id = "ml-auto-solutions-dev" + environment_name = "akshu-test" + service_account_id = "akshu-dev" } -] \ No newline at end of file +] diff --git a/deployment/deployment_composer_env/.terraform.lock.hcl b/deployment/deployment_composer_env/.terraform.lock.hcl new file mode 100644 index 00000000..79b9ee58 --- /dev/null +++ b/deployment/deployment_composer_env/.terraform.lock.hcl @@ -0,0 +1,21 @@ +# This file is maintained automatically by "terraform init". +# Manual edits may be lost in future updates. + +provider "registry.terraform.io/hashicorp/google-beta" { + version = "6.7.0" + hashes = [ + "h1:w5bxwp3tSvAViwW/14MyjaWXbA3bdSdx0nxnNv0OOEw=", + "zh:0def181a8781c13f002bf77afb83ce600bd5c08e4324f5d2c2e4d60a8b41dd5c", + "zh:323453165fa8c69a4400c0c14750bce9ac9729f2313966e983a13aa8a0d68f9e", + "zh:5507893e7115d0702de0a87f583d09639e082c825ec1e5bd1e4317061d5c7b04", + "zh:74505c85aaa9ef70491bdb708e0aee459fb1b85848d170217a6bf0c4cd2edf5c", + "zh:81d7f41eda8547e140304320fa0d188be26982234831c8405db5a91aef20e478", + "zh:95cd1322406f7cd97e6e4c0ddedf60060fa70481c9134871148ef6dd71d2254d", + "zh:a4ba2b745d35a7e09d318f30f0057c3b06d4443b5e08abd20041e17be3981c7d", + "zh:a5a49279474ddb12c8902f33ee9a714800cf0860c10599984a5145e63aec3e8c", + "zh:abbcc76e8ea53a92f701ddb15eb5771306ae19273e5b199b2cb1370723b16d1d", + "zh:ac942c462a313861146df88d0d1223e5b59d7c9acd83066dc0b8f2565067e558", + "zh:c9b00cc290736b956fb2d7f1c478671c8ef25d8a07526c791c9cf218580de1b9", + "zh:f569b65999264a9416862bca5cd2a6177d94ccb0424f3a4ef424428912b9cb3c", + ] +} diff --git a/deployment/modules/composer_env/main.tf b/deployment/modules/composer_env/main.tf index d6abe9f7..2b8ec148 100644 --- a/deployment/modules/composer_env/main.tf +++ b/deployment/modules/composer_env/main.tf @@ -10,7 +10,7 @@ resource "google_composer_environment" "example_environment" { airflow_config_overrides = { # TODO: Update this to allowed_deserialization_classes_regexp with Airflow 2.8.1 # https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html#allowed-deserialization-classes-regexp - core-allowed_deserialization_classes = ".*" + core-allowed_deserialization_classes_regexp = ".*" scheduler-min_file_process_interval = "120" } # Note: keep this in sync with .github/requirements.txt diff --git a/deployment/provider.tf b/deployment/provider.tf index 1b0b3da8..eeec6c48 100644 --- a/deployment/provider.tf +++ b/deployment/provider.tf @@ -5,7 +5,7 @@ provider "google-beta" { terraform { backend "gcs" { - bucket = "composer-ml-auto-solutions-tfstate" + bucket = "us-central1-akshu-test-225d2657-bucket" prefix = "terraform/state" } } From 9dcadc8cadcac3012d90d74153ec6124e021f989 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 25 Oct 2024 22:56:14 +0000 Subject: [PATCH 02/24] fix () --- dags/multipod/configs/jax_tests_gce_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index 159ee69b..d956a172 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -41,7 +41,7 @@ def get_jax_distributed_initialize_config( num_slices: int = 1, base_output_dir: str = gcs_bucket.BASE_OUTPUT_DIR, ): - set_up_cmds = 'pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; python -m pip install fabric;' + set_up_cmds = ('pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; python -m pip install fabric;') run_model_cmds = ( ( "python -m 'print(10111); import jax; jax.distributed.initialize(); print(20111);'" From 08e6f6cfff04ea3aa3f137a65c79caa9026add58 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 25 Oct 2024 23:03:22 +0000 Subject: [PATCH 03/24] Fix the setup command --- dags/multipod/configs/jax_tests_gce_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index d956a172..86f1a0c9 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -41,7 +41,10 @@ def get_jax_distributed_initialize_config( num_slices: int = 1, base_output_dir: str = gcs_bucket.BASE_OUTPUT_DIR, ): - set_up_cmds = ('pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; python -m pip install fabric;') + set_up_cmds = ( + 'pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html, + 'python -m pip install fabric;', + ) run_model_cmds = ( ( "python -m 'print(10111); import jax; jax.distributed.initialize(); print(20111);'" From 0d4a54cbc9589ba7474364caa8033c1c09e200a7 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 1 Nov 2024 20:54:48 +0000 Subject: [PATCH 04/24] Fixes --- dags/multipod/configs/jax_tests_gce_config.py | 23 +++---- dags/multipod/jax_functional_tests.py | 64 +++++++++---------- deployment/cloud_composer.auto.tfvars | 2 +- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index 86f1a0c9..d650e356 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -22,8 +22,7 @@ PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value -# PROJECT_NAME = Project.TPU_PROD_ENV_MULTIPOD.value -# RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value + def get_jax_distributed_initialize_config( tpu_version: TpuVersion, @@ -37,19 +36,17 @@ def get_jax_distributed_initialize_config( network: str = "default", subnetwork: str = "default", is_tpu_reserved: bool = True, - automated_test: bool = True, num_slices: int = 1, - base_output_dir: str = gcs_bucket.BASE_OUTPUT_DIR, ): - set_up_cmds = ( - 'pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html, - 'python -m pip install fabric;', - ) - run_model_cmds = ( - ( - "python -m 'print(10111); import jax; jax.distributed.initialize(); print(20111);'" - ), - ) # TODO: Do we need to verify success or will the command fail if execution fails? + test_platform = common.Platform.GCE + set_up_cmds = common.setup_maxtext(test_mode, test_platform) + set_up_cmds = [ + "pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html", + ] + run_model_cmds = [ + "python3 -c 'import jax; jax.distributed.initialize()'", + ] + job_test_config = test_config.TpuVmTest( test_config.Tpu( version=tpu_version, diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index ddd5147f..7a3fbbfb 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Zone +from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion from dags.multipod.configs import jax_tests_gce_config from dags.multipod.configs.common import SetupMode @@ -57,36 +57,36 @@ ) # v5p - # v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value - # v5p_network = V5_NETWORKS - # v5p_subnetwork = V5P_SUBNETWORKS - # v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value + v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value + v5p_network = V5_NETWORKS + v5p_subnetwork = V5P_SUBNETWORKS + v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value - # jax_nightly_1slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - # tpu_version=TpuVersion.V5P, - # tpu_cores=8, - # tpu_zone=Zone.US_EAST5_A.value, - # runtime_version=v5p_runtime_version, - # project_name=v5p_project_name, - # time_out_in_min=60, - # is_tpu_reserved=True, - # test_name=default_test_name, - # test_mode=test_mode, - # network=v5p_network, - # subnetwork=v5p_subnetwork, - # ) + jax_nightly_1slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=default_test_name, + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) - # jax_nightly_2slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - # tpu_version=TpuVersion.V5P, - # tpu_cores=8, - # num_slices=2, - # tpu_zone=Zone.US_EAST5_A.value, - # runtime_version=v5p_runtime_version, - # project_name=v5p_project_name, - # time_out_in_min=60, - # is_tpu_reserved=True, - # test_name=default_test_name, - # test_mode=test_mode, - # network=v5p_network, - # subnetwork=v5p_subnetwork, - # ) + jax_nightly_2slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + num_slices=2, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=default_test_name, + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) diff --git a/deployment/cloud_composer.auto.tfvars b/deployment/cloud_composer.auto.tfvars index c7cfa022..2fed2f17 100644 --- a/deployment/cloud_composer.auto.tfvars +++ b/deployment/cloud_composer.auto.tfvars @@ -6,7 +6,7 @@ project_config = { environment_config = [ { - environment_name = "akshu-test" + environment_name = "ml-automation-solutions-dev" service_account_id = "akshu-dev" } ] From 94eb764d71ebb4257c7a46fe5744f193684683c7 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 1 Nov 2024 21:08:25 +0000 Subject: [PATCH 05/24] Schedule the test --- dags/multipod/jax_functional_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index 7a3fbbfb..c0801f33 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -21,8 +21,8 @@ from dags.multipod.configs import jax_tests_gce_config from dags.multipod.configs.common import SetupMode -# Run once a day at am UTC ( am PST) # TODO time -SCHEDULED_TIME = "0 9 * * *" if composer_env.is_prod_env() else None # TODO time +# Run once a day at 10 am UTC (2 am PST) +SCHEDULED_TIME = "0 10 * * *" if composer_env.is_prod_env() else None with models.DAG( dag_id="jax_functional_tests", From aa1787c5cb98589866dd4d4b9dceb4517beecb50 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:29:05 -0700 Subject: [PATCH 06/24] Delete dags/multipod/configs/jax_tests_gke_config.py --- dags/multipod/configs/jax_tests_gke_config.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 dags/multipod/configs/jax_tests_gke_config.py diff --git a/dags/multipod/configs/jax_tests_gke_config.py b/dags/multipod/configs/jax_tests_gke_config.py deleted file mode 100644 index e69de29b..00000000 From c20a099ad4dd61b25c9b8e9a5bd81ba04363e2cc Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:30:23 -0700 Subject: [PATCH 07/24] Update cloud_composer.auto.tfvars --- deployment/cloud_composer.auto.tfvars | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deployment/cloud_composer.auto.tfvars b/deployment/cloud_composer.auto.tfvars index 2fed2f17..d86d4d6a 100644 --- a/deployment/cloud_composer.auto.tfvars +++ b/deployment/cloud_composer.auto.tfvars @@ -7,6 +7,6 @@ project_config = { environment_config = [ { environment_name = "ml-automation-solutions-dev" - service_account_id = "akshu-dev" + service_account_id = "ml-auto-solutions-dev" } ] From 45b8bcfdc1a1e0977440717a93df217b9b1ebbca Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:31:16 -0700 Subject: [PATCH 08/24] Delete deployment/deployment_composer_env/.terraform.lock.hcl --- .../.terraform.lock.hcl | 21 ------------------- 1 file changed, 21 deletions(-) delete mode 100644 deployment/deployment_composer_env/.terraform.lock.hcl diff --git a/deployment/deployment_composer_env/.terraform.lock.hcl b/deployment/deployment_composer_env/.terraform.lock.hcl deleted file mode 100644 index 79b9ee58..00000000 --- a/deployment/deployment_composer_env/.terraform.lock.hcl +++ /dev/null @@ -1,21 +0,0 @@ -# This file is maintained automatically by "terraform init". -# Manual edits may be lost in future updates. - -provider "registry.terraform.io/hashicorp/google-beta" { - version = "6.7.0" - hashes = [ - "h1:w5bxwp3tSvAViwW/14MyjaWXbA3bdSdx0nxnNv0OOEw=", - "zh:0def181a8781c13f002bf77afb83ce600bd5c08e4324f5d2c2e4d60a8b41dd5c", - "zh:323453165fa8c69a4400c0c14750bce9ac9729f2313966e983a13aa8a0d68f9e", - "zh:5507893e7115d0702de0a87f583d09639e082c825ec1e5bd1e4317061d5c7b04", - "zh:74505c85aaa9ef70491bdb708e0aee459fb1b85848d170217a6bf0c4cd2edf5c", - "zh:81d7f41eda8547e140304320fa0d188be26982234831c8405db5a91aef20e478", - "zh:95cd1322406f7cd97e6e4c0ddedf60060fa70481c9134871148ef6dd71d2254d", - "zh:a4ba2b745d35a7e09d318f30f0057c3b06d4443b5e08abd20041e17be3981c7d", - "zh:a5a49279474ddb12c8902f33ee9a714800cf0860c10599984a5145e63aec3e8c", - "zh:abbcc76e8ea53a92f701ddb15eb5771306ae19273e5b199b2cb1370723b16d1d", - "zh:ac942c462a313861146df88d0d1223e5b59d7c9acd83066dc0b8f2565067e558", - "zh:c9b00cc290736b956fb2d7f1c478671c8ef25d8a07526c791c9cf218580de1b9", - "zh:f569b65999264a9416862bca5cd2a6177d94ccb0424f3a4ef424428912b9cb3c", - ] -} From 411ab31b5f87a581b43df8bddba8c9493cd28778 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:32:27 -0700 Subject: [PATCH 09/24] Update main.tf --- deployment/modules/composer_env/main.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deployment/modules/composer_env/main.tf b/deployment/modules/composer_env/main.tf index 2b8ec148..d6abe9f7 100644 --- a/deployment/modules/composer_env/main.tf +++ b/deployment/modules/composer_env/main.tf @@ -10,7 +10,7 @@ resource "google_composer_environment" "example_environment" { airflow_config_overrides = { # TODO: Update this to allowed_deserialization_classes_regexp with Airflow 2.8.1 # https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html#allowed-deserialization-classes-regexp - core-allowed_deserialization_classes_regexp = ".*" + core-allowed_deserialization_classes = ".*" scheduler-min_file_process_interval = "120" } # Note: keep this in sync with .github/requirements.txt From 1786a74407760187815ff911e63469aeb4cd7938 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:32:53 -0700 Subject: [PATCH 10/24] Update cloud_composer.auto.tfvars From 41dfea2a0b62381389e3e257cf0d5183c0597421 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:33:23 -0700 Subject: [PATCH 11/24] Update provider.tf --- deployment/provider.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deployment/provider.tf b/deployment/provider.tf index eeec6c48..1b0b3da8 100644 --- a/deployment/provider.tf +++ b/deployment/provider.tf @@ -5,7 +5,7 @@ provider "google-beta" { terraform { backend "gcs" { - bucket = "us-central1-akshu-test-225d2657-bucket" + bucket = "composer-ml-auto-solutions-tfstate" prefix = "terraform/state" } } From 622fb0dd2497db693e9db26928368bc7c8c4db13 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:33:38 -0700 Subject: [PATCH 12/24] Update cloud_composer.auto.tfvars From 821a06e8c22be0b3b1038fd120a288d42efbbe4e Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:34:43 -0700 Subject: [PATCH 13/24] Update cloud_composer.auto.tfvars --- deployment/cloud_composer.auto.tfvars | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deployment/cloud_composer.auto.tfvars b/deployment/cloud_composer.auto.tfvars index d86d4d6a..7fdeef8b 100644 --- a/deployment/cloud_composer.auto.tfvars +++ b/deployment/cloud_composer.auto.tfvars @@ -5,6 +5,10 @@ project_config = { } environment_config = [ + { + environment_name = "ml-automation-solutions" + service_account_id = "ml-auto-solutions" + }, { environment_name = "ml-automation-solutions-dev" service_account_id = "ml-auto-solutions-dev" From edf12d0efe65d5c6d350656e9c53c16936cc5394 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:39:21 -0700 Subject: [PATCH 14/24] Update jax_tests_gce_config.py --- dags/multipod/configs/jax_tests_gce_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index d650e356..46dba4a9 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -41,10 +41,10 @@ def get_jax_distributed_initialize_config( test_platform = common.Platform.GCE set_up_cmds = common.setup_maxtext(test_mode, test_platform) set_up_cmds = [ - "pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html", + "pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html", ] run_model_cmds = [ - "python3 -c 'import jax; jax.distributed.initialize()'", + "python3 -c 'import jax; jax.distributed.initialize()'", ] job_test_config = test_config.TpuVmTest( @@ -67,7 +67,7 @@ def get_jax_distributed_initialize_config( job_gcp_config = gcp_config.GCPConfig( project_name=project_name, zone=tpu_zone, - dataset_name=metric_config.DatasetOption.XLML_DATASET, # TODO: can remove? + dataset_name=metric_config.DatasetOption.XLML_DATASET, ) return task.run_queued_resource_test( From 973f6279677dd6b82422a895b89049ca59e78df2 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:40:38 -0700 Subject: [PATCH 15/24] Update jax_functional_tests.py --- dags/multipod/jax_functional_tests.py | 96 ++++++++++++++------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index c0801f33..b7ab7198 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -35,26 +35,28 @@ test_mode = SetupMode.NIGHTLY # v4 - jax_nightly_1slice_v4_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V4, - tpu_cores=8, - tpu_zone=Zone.US_CENTRAL2_B.value, - time_out_in_min=60, - is_tpu_reserved=False, - test_name=default_test_name, - test_mode=test_mode, - ) + jax_nightly_1slice_v4_8 = + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + test_name=default_test_name, + test_mode=test_mode, + ) - jax_nightly_2slice_v4_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V4, - tpu_cores=8, - tpu_zone=Zone.US_CENTRAL2_B.value, - time_out_in_min=60, - is_tpu_reserved=False, - num_slices=2, - test_name=default_test_name, - test_mode=test_mode, - ) + jax_nightly_2slice_v4_8 = + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=2, + test_name=default_test_name, + test_mode=test_mode, + ) # v5p v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value @@ -62,31 +64,33 @@ v5p_subnetwork = V5P_SUBNETWORKS v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value - jax_nightly_1slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V5P, - tpu_cores=8, - tpu_zone=Zone.US_EAST5_A.value, - runtime_version=v5p_runtime_version, - project_name=v5p_project_name, - time_out_in_min=60, - is_tpu_reserved=True, - test_name=default_test_name, - test_mode=test_mode, - network=v5p_network, - subnetwork=v5p_subnetwork, - ) + jax_nightly_1slice_v5p_8 = + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=default_test_name, + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) - jax_nightly_2slice_v5p_8 = jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V5P, - tpu_cores=8, - num_slices=2, - tpu_zone=Zone.US_EAST5_A.value, - runtime_version=v5p_runtime_version, - project_name=v5p_project_name, - time_out_in_min=60, - is_tpu_reserved=True, - test_name=default_test_name, - test_mode=test_mode, - network=v5p_network, - subnetwork=v5p_subnetwork, - ) + jax_nightly_2slice_v5p_8 = + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + num_slices=2, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=default_test_name, + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) From b33e633e82810508675b82668ab65ac14fe3f8f7 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 14:43:24 -0700 Subject: [PATCH 16/24] Update jax_functional_tests.py --- dags/multipod/jax_functional_tests.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index b7ab7198..e3e6d91f 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -35,7 +35,7 @@ test_mode = SetupMode.NIGHTLY # v4 - jax_nightly_1slice_v4_8 = + jax_nightly_1slice_v4_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( tpu_version=TpuVersion.V4, tpu_cores=8, @@ -44,9 +44,10 @@ is_tpu_reserved=False, test_name=default_test_name, test_mode=test_mode, - ) + ) + ) - jax_nightly_2slice_v4_8 = + jax_nightly_2slice_v4_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( tpu_version=TpuVersion.V4, tpu_cores=8, @@ -57,6 +58,7 @@ test_name=default_test_name, test_mode=test_mode, ) + ) # v5p v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value @@ -64,7 +66,7 @@ v5p_subnetwork = V5P_SUBNETWORKS v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value - jax_nightly_1slice_v5p_8 = + jax_nightly_1slice_v5p_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( tpu_version=TpuVersion.V5P, tpu_cores=8, @@ -78,8 +80,9 @@ network=v5p_network, subnetwork=v5p_subnetwork, ) + ) - jax_nightly_2slice_v5p_8 = + jax_nightly_2slice_v5p_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( tpu_version=TpuVersion.V5P, tpu_cores=8, @@ -94,3 +97,4 @@ network=v5p_network, subnetwork=v5p_subnetwork, ) + ) From 52e6ea2488d2d638a706ae6640cf531661333d15 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 16:10:14 -0700 Subject: [PATCH 17/24] Update cloud_composer.auto.tfvars From 37349530d59a9e5cc33db6128f025aad1dcff7f2 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 16:12:54 -0700 Subject: [PATCH 18/24] Update cloud_composer.auto.tfvars From 6654ebac71f0225aeff4a1d5db1333e6799f84a9 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 16:14:42 -0700 Subject: [PATCH 19/24] Update cloud_composer.auto.tfvars From 5342b700bcc12f2c95f7212169c9b3eaca095454 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Fri, 1 Nov 2024 16:15:28 -0700 Subject: [PATCH 20/24] Update cloud_composer.auto.tfvars From ca0df82b3e98bbec4974907c023b2bd62fd54668 Mon Sep 17 00:00:00 2001 From: Akanksha Date: Mon, 4 Nov 2024 14:55:12 -0800 Subject: [PATCH 21/24] Update jax_tests_gce_config.py --- dags/multipod/configs/jax_tests_gce_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index 46dba4a9..97ac97f0 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From b084ec1bc713a4f6b13b280a5a288e0d97e5c47e Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Tue, 5 Nov 2024 06:08:00 +0000 Subject: [PATCH 22/24] PR comments --- dags/multipod/configs/jax_tests_gce_config.py | 2 +- dags/multipod/jax_functional_tests.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index 97ac97f0..0f31af96 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities to construct configs for JAX tests.""" +"""Utilities to construct configs for JAX tests for GCE.""" from xlml.apis import gcp_config, metric_config, task, test_config from dags import test_owner, gcs_bucket diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index e3e6d91f..b92a37ea 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -60,6 +60,8 @@ ) ) + jax_nightly_1slice_v4_8 >> jax_nightly_2slice_v4_8 + # v5p v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value v5p_network = V5_NETWORKS @@ -98,3 +100,5 @@ subnetwork=v5p_subnetwork, ) ) + +jax_nightly_1slice_v5p_8 >> jax_nightly_2slice_v5p_8 From 93d890c96ac01c84f253d9aa750cafaa00d77ecc Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 8 Nov 2024 19:30:13 +0000 Subject: [PATCH 23/24] Run the GCE JAX test on nightly, stable and stable stack --- dags/multipod/jax_functional_tests.py | 129 +++++++++++++------------- 1 file changed, 65 insertions(+), 64 deletions(-) diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index b92a37ea..892149ba 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -31,74 +31,75 @@ start_date=datetime.datetime(2024, 10, 23), catchup=False, ) as dag: - default_test_name = "jax" - test_mode = SetupMode.NIGHTLY + default_test_name = "jax-distributed-initialize" + test_modes = [SetupMode.STABLE, SetupMode.NIGHTLY, SetupMode.JAX_STABLE_STACK] - # v4 - jax_nightly_1slice_v4_8 = ( - jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V4, - tpu_cores=8, - tpu_zone=Zone.US_CENTRAL2_B.value, - time_out_in_min=60, - is_tpu_reserved=False, - test_name=default_test_name, - test_mode=test_mode, - ) - ) + for test_mode in test_modes: + # v4 + jax_nightly_1slice_v4_8 = ( + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + test_name=f"{default_test_name}-1xv4-8-{test_mode.value}", + test_mode=test_mode, + ) + ) - jax_nightly_2slice_v4_8 = ( - jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V4, - tpu_cores=8, - tpu_zone=Zone.US_CENTRAL2_B.value, - time_out_in_min=60, - is_tpu_reserved=False, - num_slices=2, - test_name=default_test_name, - test_mode=test_mode, - ) - ) + jax_nightly_2slice_v4_8 = ( + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=2, + test_name=f"{default_test_name}-2xv4-8-{test_mode.value}", + test_mode=test_mode, + ) + ) - jax_nightly_1slice_v4_8 >> jax_nightly_2slice_v4_8 + jax_nightly_1slice_v4_8 >> jax_nightly_2slice_v4_8 - # v5p - v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value - v5p_network = V5_NETWORKS - v5p_subnetwork = V5P_SUBNETWORKS - v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value + # v5p + v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value + v5p_network = V5_NETWORKS + v5p_subnetwork = V5P_SUBNETWORKS + v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value - jax_nightly_1slice_v5p_8 = ( - jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V5P, - tpu_cores=8, - tpu_zone=Zone.US_EAST5_A.value, - runtime_version=v5p_runtime_version, - project_name=v5p_project_name, - time_out_in_min=60, - is_tpu_reserved=True, - test_name=default_test_name, - test_mode=test_mode, - network=v5p_network, - subnetwork=v5p_subnetwork, - ) - ) + jax_nightly_1slice_v5p_8 = ( + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=f"{default_test_name}-1xv5p-8-{test_mode.value}", + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) + ) - jax_nightly_2slice_v5p_8 = ( - jax_tests_gce_config.get_jax_distributed_initialize_config( - tpu_version=TpuVersion.V5P, - tpu_cores=8, - num_slices=2, - tpu_zone=Zone.US_EAST5_A.value, - runtime_version=v5p_runtime_version, - project_name=v5p_project_name, - time_out_in_min=60, - is_tpu_reserved=True, - test_name=default_test_name, - test_mode=test_mode, - network=v5p_network, - subnetwork=v5p_subnetwork, - ) - ) + jax_nightly_2slice_v5p_8 = ( + jax_tests_gce_config.get_jax_distributed_initialize_config( + tpu_version=TpuVersion.V5P, + tpu_cores=8, + num_slices=2, + tpu_zone=Zone.US_EAST5_A.value, + runtime_version=v5p_runtime_version, + project_name=v5p_project_name, + time_out_in_min=60, + is_tpu_reserved=True, + test_name=f"{default_test_name}-2xv5p-8-{test_mode.value}", + test_mode=test_mode, + network=v5p_network, + subnetwork=v5p_subnetwork, + ) + ) -jax_nightly_1slice_v5p_8 >> jax_nightly_2slice_v5p_8 + jax_nightly_1slice_v5p_8 >> jax_nightly_2slice_v5p_8 From 6ffbad3aa183f7fb86a4331e1b865cda9fe9f2da Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Fri, 8 Nov 2024 20:09:28 +0000 Subject: [PATCH 24/24] Run the GCE JAX test on nightly, stable and stable stack --- dags/multipod/jax_functional_tests.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index 892149ba..baa53dc9 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -33,6 +33,7 @@ ) as dag: default_test_name = "jax-distributed-initialize" test_modes = [SetupMode.STABLE, SetupMode.NIGHTLY, SetupMode.JAX_STABLE_STACK] + v4_task_arr, v5p_task_arr = [], [] for test_mode in test_modes: # v4 @@ -43,10 +44,14 @@ tpu_zone=Zone.US_CENTRAL2_B.value, time_out_in_min=60, is_tpu_reserved=False, - test_name=f"{default_test_name}-1xv4-8-{test_mode.value}", + test_name=f"{default_test_name}-{test_mode.value}", test_mode=test_mode, ) ) + if len(v4_task_arr) > 1: + # pylint: disable-next=pointless-statement + v4_task_arr[-1] >> jax_nightly_1slice_v4_8 + v4_task_arr.append(jax_nightly_1slice_v4_8) jax_nightly_2slice_v4_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( @@ -56,12 +61,14 @@ time_out_in_min=60, is_tpu_reserved=False, num_slices=2, - test_name=f"{default_test_name}-2xv4-8-{test_mode.value}", + test_name=f"{default_test_name}-{test_mode.value}", test_mode=test_mode, ) ) - jax_nightly_1slice_v4_8 >> jax_nightly_2slice_v4_8 + # pylint: disable-next=pointless-statement + v4_task_arr[-1] >> jax_nightly_2slice_v4_8 + v4_task_arr.append(jax_nightly_2slice_v4_8) # v5p v5p_project_name = Project.TPU_PROD_ENV_AUTOMATED.value @@ -78,12 +85,16 @@ project_name=v5p_project_name, time_out_in_min=60, is_tpu_reserved=True, - test_name=f"{default_test_name}-1xv5p-8-{test_mode.value}", + test_name=f"{default_test_name}-{test_mode.value}", test_mode=test_mode, network=v5p_network, subnetwork=v5p_subnetwork, ) ) + if len(v5p_task_arr) > 1: + # pylint: disable-next=pointless-statement + v5p_task_arr[-1] >> jax_nightly_1slice_v5p_8 + v5p_task_arr.append(jax_nightly_1slice_v5p_8) jax_nightly_2slice_v5p_8 = ( jax_tests_gce_config.get_jax_distributed_initialize_config( @@ -95,11 +106,13 @@ project_name=v5p_project_name, time_out_in_min=60, is_tpu_reserved=True, - test_name=f"{default_test_name}-2xv5p-8-{test_mode.value}", + test_name=f"{default_test_name}-{test_mode.value}", test_mode=test_mode, network=v5p_network, subnetwork=v5p_subnetwork, ) ) - jax_nightly_1slice_v5p_8 >> jax_nightly_2slice_v5p_8 + # pylint: disable-next=pointless-statement + v5p_task_arr[-1] >> jax_nightly_2slice_v5p_8 + v5p_task_arr.append(jax_nightly_2slice_v5p_8)