Skip to content

Commit

Permalink
refactor pathways jobset to new spec
Browse files Browse the repository at this point in the history
  • Loading branch information
jesus-orozco committed Jan 10, 2025
1 parent 357a322 commit 89ac1ea
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,6 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}

container_name = cfg.name
args = []
image = self._bundler.id(cfg.name)
ports = [
Expand All @@ -577,25 +576,19 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
]

if self.using_pathways:
container_name = f"{cfg.name}-{job_type}"
staging_location = f"{cfg.output_dir}/pathways-staging/tmp"
# cluster = flags.FLAGS.cluster or gcp_settings("gke_cluster", required=False, fv=flags.FLAGS)
# pathways_proxy_port, pathways_rm_port = 38676, 38677
# rm_address = f"{cfg.name}-rm-0-0.{cfg.name}.default.svc.{cluster}-domain:{pathways_port}"
# rm_address = f"{cfg.name}-rm-0-0.{cfg.name}:{pathways_port}"
env_vars.update(
# dump XLA flags to GCS bucket for troubleshooting purposes.
XLA_FLAGS=f"--xla_dump_to={cfg.output_dir}/{cfg.name}/xla/"
XLA_FLAGS="--xla_dump_to=gs://ttl-30d-us-central2/axlearn/users/jesusfc/pathways/v6e/xla/"
)

if job_type == "pathways-head":
# resources["limits"]["memory"] = "100Gi"
# resources["limits"]["cpu"] = "24"
env_vars.update(
JAX_BACKEND_TARGET="grpc://$(HOST_ADDRESS):38681",
# JAX_BACKEND_TARGET="grpc://$(HOST_ADDRESS):29000",
JAX_BACKEND_TARGET=f"grpc://{cfg.name}-{job_type}-0-0.{cfg.name}:29000",
XCLOUD_ENVIRONMENT="GCP",
JAX_PLATFORMS="proxy",
ENABLE_PATHWAYS_PERSISTENCE="true",
ENABLE_PATHWAYS_PERSISTENCE="1",
TPU_SKIP_MDS_QUERY="true",
)
elif job_type == "pathways-workers":
Expand All @@ -604,13 +597,13 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
)
args.extend(
[
"--server_port=38679",
"--resource_manager_address=$(PATHWAYS_HEAD):38677",
"--server_port=29001",
"--resource_manager_address=$(PATHWAYS_HEAD):29001",
f"--gcs_scratch_location={staging_location}",
]
)
image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
ports.append(dict(containerPort=38679))
ports.append(dict(containerPort=29001))
resources = {"limits": {"google.com/tpu": system.chips_per_vm}}

k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
Expand Down Expand Up @@ -676,7 +669,7 @@ def _build_container(self, job_type: str = None) -> Nested[Any]:
)

return dict(
name=container_name,
name=cfg.name,
image=image,
# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#tpu-chips-node-pool
# https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice#run_workload
Expand Down Expand Up @@ -759,18 +752,18 @@ def _build_pathways_containers(self) -> list[dict]:
}
],
args=[
"--resource_manager_address=$(PATHWAYS_HEAD):38677",
"--server_port=38681",
"--resource_manager_address=$(PATHWAYS_HEAD):29001",
"--server_port=29000",
f"--gcs_scratch_location={staging_location}",
],
ports=[dict(containerPort=38681)],
# resources={"limits": {"cpu": "24", "memory": "100Gi"},},
ports=[dict(containerPort=29000)],
),
dict(
name="pathways-rm",
image="us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest",
# https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
# SideCar container is an init container with restartPolicy as "Always".
restartPolicy="Always",
env=[
{
"name": "HOST_ADDRESS",
Expand All @@ -786,13 +779,12 @@ def _build_pathways_containers(self) -> list[dict]:
},
],
args=[
"--server_port=38677",
"--server_port=29001",
"--node_type=resource_manager",
f"--instance_count={cfg.accelerator.num_replicas}",
f"--instance_type={tpu_type}:{system.topology}",
f"--gcs_scratch_location={staging_location}",
],
# resources={"limits": {"cpu": "4", "memory": "8Gi"},},
),
]

Expand Down Expand Up @@ -942,6 +934,7 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:

if job_type == "pathways-head":
# selector.update({"pathways-head": "true"})
selector.update({"cloud.google.com/gke-nodepool": "pathways-head"})
initContainers.extend(self._build_pathways_containers())
else:
selector.update(
Expand Down Expand Up @@ -1006,9 +999,9 @@ def _build_job(self, job_type: str = None) -> Nested[Any]:
template=self._build_pod(job_type),
)
elif job_type == "pathways-head":
annotations.update(
{"alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname"}
)
#annotations.update(
# {"alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname"}
#)
spec.update(
parallelism=1,
completions=1,
Expand Down

0 comments on commit 89ac1ea

Please sign in to comment.