From e9d83410a81458b15b5ea3c04e8039216af9e055 Mon Sep 17 00:00:00 2001
From: Jesus Orozco
Date: Thu, 12 Dec 2024 21:19:51 +0000
Subject: [PATCH] Refactor jobset to align with new pathways structure

---
 axlearn/cloud/gcp/job.py | 328 ++++++++++++++++++++++-----------------
 1 file changed, 189 insertions(+), 139 deletions(-)

diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py
index c662b6c2..69c1a25c 100644
--- a/axlearn/cloud/gcp/job.py
+++ b/axlearn/cloud/gcp/job.py
@@ -6,7 +6,6 @@
 """
 import atexit
 import importlib
-import importlib
 import io
 import logging
 import math
@@ -505,7 +504,10 @@ def _is_pathways_used(self) -> bool:
         # identify if a job is configured to use pathways by
         # checking jax_backend flag and optional import for pathways utils
         # brittle implementation
-        return "pathwaysutils" in self.config.import_modules and "jax_backend proxy" in self.config.command
+        return (
+            "pathwaysutils" in self.config.import_modules
+            and "jax_backend proxy" in self.config.command
+        )
 
     def _import_modules(self):
         try:
@@ -576,82 +578,40 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
         if self.using_pathways:
             container_name = f"{cfg.name}-{job_type}"
-            volume_mounts.append(
-                dict(
-                    name="shared-tmp",
-                    mountPath="/tmp",
-                ),
-            )
             staging_location = f"{cfg.output_dir}/pathways-staging/tmp"
-            cluster = flags.FLAGS.cluster or gcp_settings("gke_cluster", required=False, fv=flags.FLAGS)
-            pathways_port = 38677
-            rm_address = f"{cfg.name}-rm-0-0.{cfg.name}.default.svc.{cluster}-domain:{pathways_port}"
-            #rm_address = f"{cfg.name}-rm-0-0.{cfg.name}:{pathways_port}"
+            # cluster = flags.FLAGS.cluster or gcp_settings("gke_cluster", required=False, fv=flags.FLAGS)
+            # pathways_proxy_port, pathways_rm_port = 38676, 38677
+            # rm_address = f"{cfg.name}-rm-0-0.{cfg.name}.default.svc.{cluster}-domain:{pathways_port}"
+            # rm_address = f"{cfg.name}-rm-0-0.{cfg.name}:{pathways_port}"
             env_vars.update(
+                # Dump XLA compilation output to the GCS output bucket for troubleshooting.
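+                # (Assumes the XLA runtime's filesystem layer can write directly to
+                # gs:// paths; XLA reads XLA_FLAGS once at process startup.)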
                 XLA_FLAGS=f"--xla_dump_to={cfg.output_dir}/{cfg.name}/xla/"
             )
 
-            if job_type == "worker":
-                args.extend(
-                    [
-                        f"--server_port={pathways_port}",
-                        f"--resource_manager_address={rm_address}",
-                        f"--gcs_scratch_location={staging_location}",
-                    ]
-                )
-                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
-                ports.append(dict(containerPort=pathways_port))
-                resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
-
-            elif job_type == "rm":
-                tpu_type = self._get_pathways_tpu_type(system.device_type)
-                args.extend(
-                    [
-                        f"--server_port={pathways_port}",
-                        "--node_type=resource_manager",
-                        f"--gcs_scratch_location={staging_location}",
-                        f"--instance_count={cfg.accelerator.num_replicas}",
-                        f"--instance_type={tpu_type}:{system.topology}",
-                    ]
-                )
+            if job_type == "pathways-head":
+                # resources["limits"]["memory"] = "100Gi"
+                # resources["limits"]["cpu"] = "24"
                 env_vars.update(
+                    JAX_BACKEND_TARGET="grpc://$(HOST_ADDRESS):38681",
+                    XCLOUD_ENVIRONMENT="GCP",
+                    JAX_PLATFORMS="proxy",
+                    ENABLE_PATHWAYS_PERSISTENCE="true",
                     TPU_SKIP_MDS_QUERY="true",
-                    HOST_ADDRESS=f"{cfg.name}-{job_type}-0-0.{cfg.name}",
-                    #HOST_ADDRESS="$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)",
-                    REPLICATED_JOB_NAME=job_type,
-                    JOBSET_NAME=cfg.name,
                 )
-                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
-                resources["limits"]["memory"] = "8Gi"
-                resources["limits"]["cpu"] = "4"
-                ports.append(dict(containerPort=pathways_port))
-
-            elif job_type == "proxy":
+            elif job_type == "pathways-workers":
+                env_vars.update(
+                    MEGASCALE_COORDINATOR_ADDRESS="$(PATHWAYS_HEAD)",
+                )
                 args.extend(
                     [
-                        f"--server_port={pathways_port - 1}",
-                        f"--resource_manager_address={rm_address}",
+                        "--server_port=38679",
+                        "--resource_manager_address=$(PATHWAYS_HEAD):38677",
                         f"--gcs_scratch_location={staging_location}",
                     ]
                 )
-                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
-                resources["limits"]["memory"] = "100Gi"
-                resources["limits"]["cpu"] = "24"
-                ports.append(dict(containerPort=pathways_port-1))
-
-            elif job_type == "user":
-                resources["limits"]["memory"] = "100Gi"
-                resources["limits"]["cpu"] = "24"
-                proxy = (
-                    f"grpc://{cfg.name}-proxy-0-0.{cfg.name}.default.svc.{cluster}-domain:{pathways_port - 1}"
-                    #f"grpc://{cfg.name}-proxy-0-0.{cfg.name}:{pathways_port - 1}"
-                )
-                env_vars.update(
-                    JAX_BACKEND_TARGET=proxy,
-                    XCLOUD_ENVIRONMENT="GCP",
-                    JOBSET_NAME=cfg.name,
-                    JAX_PLATFORMS="proxy",
-                )
+                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
+                ports.append(dict(containerPort=38679))
+                resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
 
         k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
         k8s_env_vars.append(
@@ -674,6 +634,46 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
                 },
             },
         )
+        if job_type == "pathways-head":
+            # Insert at the front: JAX_BACKEND_TARGET references $(HOST_ADDRESS), and
+            # k8s only expands $(VAR) for vars defined earlier in the env list.
+            k8s_env_vars.insert(
+                0,
+                {
+                    "name": "HOST_ADDRESS",
+                    "valueFrom": {
+                        "fieldRef": {
+                            "fieldPath": "metadata.labels['jobset.sigs.k8s.io/coordinator']",
+                        }
+                    },
+                },
+            )
+        if job_type == "pathways-workers":
+            # Prepend likewise: MEGASCALE_COORDINATOR_ADDRESS references $(PATHWAYS_HEAD).
+            k8s_env_vars[0:0] = [
+                {
+                    "name": "PATHWAYS_HEAD",
+                    "valueFrom": {
+                        "fieldRef": {
+                            "fieldPath": "metadata.labels['jobset.sigs.k8s.io/coordinator']",
+                        }
+                    },
+                },
+                {
+                    "name": "MEGASCALE_NUM_SLICES",
+                    "valueFrom": {
+                        "fieldRef": {
+                            "fieldPath": "metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']",
+                        }
+                    },
+                },
+                {
+                    "name": "MEGASCALE_SLICE_ID",
+                    "valueFrom": {
+                        "fieldRef": {
+                            "fieldPath": "metadata.labels['jobset.sigs.k8s.io/job-index']",
+                        }
+                    },
+                },
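+                # (The jobset.sigs.k8s.io/* labels above are stamped onto each pod by
+                # the JobSet controller; the downward API exposes them as env vars.)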
+            ]
 
         return dict(
             name=container_name,
@@ -683,9 +683,7 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
             ports=ports,
             securityContext=dict(privileged=True),
             # TODO(markblee): Improve SIGTERM behavior for command.
-            command=["bash", "-c", cfg.command]
-            if not self.using_pathways or job_type == "user"
-            else None,
+            command=["bash", "-c", cfg.command] if job_type != "pathways-workers" else None,
             resources=resources,
             # Env var values should always be strings.
             env=k8s_env_vars,
@@ -729,9 +727,75 @@ def _build_uploader_container(self) -> Nested[Any]:
             args=[sync_command],
             resources=resources,
             volumeMounts=volume_mounts,
-            # args=args,
         )
 
+    def _build_pathways_containers(self) -> list[dict]:
+        """Builds configs for the pathways sidecar containers, which handle resource
+        management and pathways proxy communication for the head job.
+
+        Returns:
+            A list of nested dicts, each corresponding to a k8s Container config.
+        """
+        cfg: TPUGKEJob.Config = self.config
+        system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
+        staging_location = f"{cfg.output_dir}/pathways-staging/tmp"
+        tpu_type = self._get_pathways_tpu_type(system.device_type)
+
+        return [
+            dict(
+                name="pathways-proxy",
+                image="us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest",
+                # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
+                # A sidecar container is an init container whose restartPolicy is "Always".
+                restartPolicy="Always",
+                env=[
+                    {
+                        "name": "PATHWAYS_HEAD",
+                        "valueFrom": {
+                            "fieldRef": {
+                                "fieldPath": "metadata.labels['jobset.sigs.k8s.io/coordinator']",
+                            }
+                        },
+                    }
+                ],
+                args=[
+                    "--resource_manager_address=$(PATHWAYS_HEAD):38677",
+                    "--server_port=38681",
+                    f"--gcs_scratch_location={staging_location}",
+                ],
+                ports=[dict(containerPort=38681)],
+                # resources={"limits": {"cpu": "24", "memory": "100Gi"}},
+            ),
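+            # (Port wiring assumed by this refactor: 38677 = resource manager, 38679 =
+            # workers, 38681 = proxy; the head's JAX_BACKEND_TARGET reaches the proxy
+            # sidecar at $(HOST_ADDRESS):38681.)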
+            dict(
+                name="pathways-rm",
+                image="us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest",
+                # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
+                # A sidecar container is an init container whose restartPolicy is "Always".
+                restartPolicy="Always",
+                env=[
+                    {
+                        "name": "HOST_ADDRESS",
+                        "valueFrom": {
+                            "fieldRef": {
+                                "fieldPath": "metadata.labels['jobset.sigs.k8s.io/coordinator']",
+                            }
+                        },
+                    },
+                    {
+                        "name": "TPU_SKIP_MDS_QUERY",
+                        "value": "true",
+                    },
+                ],
+                args=[
+                    "--server_port=38677",
+                    "--node_type=resource_manager",
+                    f"--instance_count={cfg.accelerator.num_replicas}",
+                    f"--instance_type={tpu_type}:{system.topology}",
+                    f"--gcs_scratch_location={staging_location}",
+                ],
+                # resources={"limits": {"cpu": "4", "memory": "8Gi"}},
+            ),
+        ]
+
     def _build_pod(self, job_type: str = None) -> Nested[Any]:
         """Builds a config for a single Pod, which is a set of containers.
 
@@ -743,6 +807,7 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:
         cfg: TPUGKEJob.Config = self.config
         system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
         annotations, labels, selector, volumes, tolerations = {}, {}, {}, [], []
+        initContainers = [self._build_uploader_container()]
 
         volumes.append(dict(name="shared-output", emptyDir={}))
         if cfg.gcsfuse_mount:
@@ -787,7 +852,7 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:
         tier = os.environ.get("BASTION_TIER", None)
 
         # skip reservation/spot flags for Pathways CPU jobs.
-        if job_type not in ("rm", "proxy", "user"):
+        if job_type != "pathways-head":
             if tier == "0" and cfg.reservation is not None:
                 logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
                 selector.update({"cloud.google.com/reservation-name": cfg.reservation})
@@ -875,19 +940,9 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:
             }
         )
 
-        if self.using_pathways:
-            volumes.append(
-                dict(
-                    hostPath=dict(
-                        path="/tmp",
-                        type="DirectoryOrCreate",
-                    ),
-                    name="shared-tmp",
-                )
-            )
-
-        if job_type in ("rm", "proxy", "user"):
-            selector.update({"cloud.google.com/gke-nodepool": f"cpu-{job_type}-np"})
+        if job_type == "pathways-head":
+            # selector.update({"pathways-head": "true"})
+            initContainers.extend(self._build_pathways_containers())
         else:
             selector.update(
                 {
@@ -895,7 +950,7 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:
                     "cloud.google.com/gke-tpu-topology": system.topology,
                 }
             )
-
+
         # Hardcode metadata.google.internal ip address to avoid transient DNS resolution issue.
         metadata_host_alias = dict(
             ip=_METADATA_GOOGLE_INTERNAL_IP,
@@ -914,11 +969,11 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:
             },
             tolerations=tolerations,
            containers=[self._build_container(job_type)],
-            # initContainers=[self._build_uploader_container()],
+            initContainers=initContainers,
             serviceAccountName=cfg.service_account,
             volumes=volumes,
-            hostNetwork=True,
-            dnsPolicy="ClusterFirstWithHostNet",
+            hostNetwork=self.using_pathways,
+            dnsPolicy="ClusterFirstWithHostNet" if self.using_pathways else None,
         )
 
         if cfg.priority_class:
@@ -938,39 +993,39 @@ def _build_job(self, job_type: str = None) -> Nested[Any]:
             A nested dict corresponding to a k8s Job config, including the job metadata and spec.
         """
         system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
+        annotations, spec = {}, {}
 
-        if job_type == "worker":
-            return dict(
-                metadata=dict(
-                    annotations={
-                        # pylint: disable=line-too-long
-                        "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"
-                    }
-                ),
-                spec=dict(
-                    parallelism=system.vms_per_slice,
-                    completions=system.vms_per_slice,
-                    backoffLimit=system.vms_per_slice * 4,
-                    template=self._build_pod(job_type),
-                ),
+        if job_type == "pathways-workers":
+            annotations.update(
+                {"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
             )
-        elif job_type in ("rm", "proxy", "user"):
-            return dict(
-                spec=dict(
-                    parallelism=1,
-                    completions=1,
-                    backoffLimit=0,
-                    template=self._build_pod(job_type),
-                ),
+            spec.update(
+                parallelism=system.vms_per_slice,
+                completions=system.vms_per_slice,
+                backoffLimit=system.vms_per_slice * 4,
+                template=self._build_pod(job_type),
             )
-
-        return dict(
-            spec=dict(
+        elif job_type == "pathways-head":
+            annotations.update(
+                {"alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname"}
+            )
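+            # (exclusive-topology on kubernetes.io/hostname requests exclusive per-node
+            # placement for the single head pod, mirroring the per-node-pool exclusivity
+            # of each worker slice.)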
- "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", - } - ) - if cfg.queue: if self.using_pathways: labels["kueue.x-k8s.io/queue-name"] = cfg.queue @@ -1016,31 +1061,35 @@ def _build_jobset(self) -> Nested[Any]: if self.using_pathways: logging.info("Building pathways jobset.") spec = dict( + coordinator=dict( + replicatedJob="pathways-head", + jobIndex=0, + podIndex=0, + ), failurePolicy=dict(maxRestarts=cfg.max_tries - 1), - successPolicy=dict(operator="All", targetReplicatedJobs=["user"]), + successPolicy=dict(operator="All", targetReplicatedJobs=["pathways-head"]), replicatedJobs=[ dict( - name="worker", - replicas=cfg.accelerator.num_replicas, - template=self._build_job("worker"), - ), - dict( - name="rm", - replicas=1, - template=self._build_job("rm"), - ), - dict( - name="proxy", + name="pathways-head", replicas=1, - template=self._build_job("proxy"), + template=self._build_job("pathways-head"), ), dict( - name="user", - replicas=1, - template=self._build_job("user"), + name="pathways-workers", + replicas=cfg.accelerator.num_replicas, + template=self._build_job("pathways-workers"), ), ], ) + else: + annotations.update( + { + # The exclusive topology annotation will ensure that all Pods will have affinity + # rules added that will ensure that they are fully scheduled on the same + # pod-slice node-pools. + "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", + } + ) return dict( metadata=dict( @@ -1069,6 +1118,7 @@ def _execute(self) -> Any: with open(f"jobsets/{cfg.name}.yaml", "w") as f: logging.info("Output jobset to yaml file...") import yaml + yaml.dump(custom_object, f, default_flow_style=False) logging.info("Submitting JobSet body=%s api_kwargs=%s", custom_object, api_kwargs) return k8s.client.CustomObjectsApi().create_namespaced_custom_object(