# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for JAX tests for GCE."""

from xlml.apis import gcp_config, metric_config, task, test_config
from dags import test_owner, gcs_bucket
from dags.multipod.configs import common
from dags.vm_resource import TpuVersion, Project, RuntimeVersion
import datetime

# Defaults used when the caller does not override the project / runtime image.
PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value


def get_jax_distributed_initialize_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    test_mode: common.SetupMode,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
    num_slices: int = 1,
):
    """Build a queued-resource test that calls `jax.distributed.initialize()`.

    The test installs `jax[tpu]` with pip and then runs a one-line Python
    program that imports JAX and calls `jax.distributed.initialize()`.

    Args:
        tpu_version: TPU generation passed to `test_config.Tpu`.
        tpu_cores: Core count passed to `test_config.Tpu`.
        tpu_zone: Zone passed to `gcp_config.GCPConfig`.
        time_out_in_min: Test timeout, in minutes.
        test_name: Name given to the `TpuVmTest`.
        test_mode: Accepted for interface compatibility but currently has no
            effect on the commands that run (see the NOTE in the body).
        project_name: GCP project passed to `gcp_config.GCPConfig`.
        runtime_version: TPU runtime image passed to `test_config.Tpu`.
        network: Network passed to `test_config.Tpu`.
        subnetwork: Subnetwork passed to `test_config.Tpu`.
        is_tpu_reserved: Forwarded as `reserved` to `test_config.Tpu`.
        num_slices: Slice count passed to the `TpuVmTest`.

    Returns:
        Whatever `task.run_queued_resource_test` returns for the assembled
        test and GCP configs (presumably an Airflow task group; confirm
        against `xlml.apis.task`).
    """
    # NOTE(review): an earlier assignment
    #   set_up_cmds = common.setup_maxtext(test_mode, common.Platform.GCE)
    # was overwritten on the very next statement, so it never influenced the
    # test. The dead store is removed here; the effective commands are kept
    # exactly as they were. If the setup is meant to vary with `test_mode`,
    # that still needs to be wired in.
    set_up_cmds = [
        "pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
    ]
    run_model_cmds = [
        "python3 -c 'import jax; jax.distributed.initialize()'",
    ]

    job_test_config = test_config.TpuVmTest(
        test_config.Tpu(
            version=tpu_version,
            cores=tpu_cores,
            runtime_version=runtime_version,
            reserved=is_tpu_reserved,
            network=network,
            subnetwork=subnetwork,
        ),
        test_name=test_name,
        set_up_cmds=set_up_cmds,
        run_model_cmds=run_model_cmds,
        timeout=datetime.timedelta(minutes=time_out_in_min),
        task_owner=test_owner.AKANKSHA_G,
        num_slices=num_slices,
    )

    job_gcp_config = gcp_config.GCPConfig(
        project_name=project_name,
        zone=tpu_zone,
        dataset_name=metric_config.DatasetOption.XLML_DATASET,
    )

    return task.run_queued_resource_test(
        task_test_config=job_test_config,
        task_gcp_config=job_gcp_config,
    )
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run JAX tests."""

import datetime
from airflow import models
from dags import composer_env
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion
from dags.multipod.configs import jax_tests_gce_config
from dags.multipod.configs.common import SetupMode

# Run once a day at 10 am UTC (2 am PST); unscheduled outside prod.
SCHEDULED_TIME = "0 10 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="jax_functional_tests",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "jax"],
    start_date=datetime.datetime(2024, 10, 23),
    catchup=False,
) as dag:
    default_test_name = "jax-distributed-initialize"
    test_modes = [SetupMode.STABLE, SetupMode.NIGHTLY, SetupMode.JAX_STABLE_STACK]

    # Settings that are fixed per TPU generation. They do not depend on the
    # test mode, so they are built once instead of on every loop iteration.
    v4_kwargs = {
        "tpu_version": TpuVersion.V4,
        "tpu_cores": 8,
        "tpu_zone": Zone.US_CENTRAL2_B.value,
        "time_out_in_min": 60,
        "is_tpu_reserved": False,
    }
    v5p_kwargs = {
        "tpu_version": TpuVersion.V5P,
        "tpu_cores": 8,
        "tpu_zone": Zone.US_EAST5_A.value,
        "runtime_version": RuntimeVersion.V2_ALPHA_TPUV5.value,
        "project_name": Project.TPU_PROD_ENV_AUTOMATED.value,
        "time_out_in_min": 60,
        "is_tpu_reserved": True,
        "network": V5_NETWORKS,
        "subnetwork": V5P_SUBNETWORKS,
    }

    # Each list holds one sequential chain of tests: all v4 tests run one
    # after another, and all v5p tests run one after another. The two chains
    # are independent of each other.
    v4_task_arr, v5p_task_arr = [], []
    platform_chains = (
        (v4_kwargs, v4_task_arr),
        (v5p_kwargs, v5p_task_arr),
    )

    for test_mode in test_modes:
        test_name = f"{default_test_name}-{test_mode.value}"
        for platform_kwargs, task_arr in platform_chains:
            # Single-slice first, then two-slice, for every mode and platform.
            for num_slices in (1, 2):
                jax_test = (
                    jax_tests_gce_config.get_jax_distributed_initialize_config(
                        test_name=test_name,
                        test_mode=test_mode,
                        num_slices=num_slices,
                        **platform_kwargs,
                    )
                )
                # Link to the previous test in this chain, if there is one.
                # Only the very first test of each chain has no predecessor.
                if task_arr:
                    # pylint: disable-next=pointless-statement
                    task_arr[-1] >> jax_test
                task_arr.append(jax_test)