
use env to skip PJRT initialize #8609

Merged: 1 commit merged into master on Jan 23, 2025

Conversation

@zpcore (Collaborator) commented Jan 22, 2025

We skip the PJRT MegaScale initialization by controlling it with an environment variable.

This is a temporary fix and is intended to be rolled back.

See #8609 (comment) for the detailed motivation.

@zpcore requested review from tengyifei and bhavya01 on Jan 22, 2025
@tengyifei (Collaborator) left a comment

Suggest landing this after the libtpu change is approved.

Review thread on torch_xla/experimental/custom_kernel.py (resolved)
@zpcore (Collaborator, Author) commented Jan 22, 2025

@tengyifei, shall we cherry-pick this PR to the 2.6 release?

@tengyifei (Collaborator)

@zpcore cherry-picking is fine with me.

@zpcore (Collaborator, Author) commented Jan 23, 2025

The test failed; I think it is due to pytorch/pytorch#142859.

That PR has since been reverted.

@tengyifei (Collaborator)

Ack

@tengyifei merged commit 557d9f3 into master on Jan 23, 2025
11 of 12 checks passed
@tengyifei (Collaborator)

@zpcore thanks. The next step is to follow the process in #8455 to create a cherry-pick PR.

@bhavya01 (Collaborator)

Retrospective LGTM!

zpcore added a commit that referenced this pull request Jan 23, 2025
@zpcore deleted the piz/multipod_hack branch on Jan 23, 2025
@miladm (Collaborator) commented Jan 31, 2025

Thanks @zpcore. Can we please add enough detail to PR descriptions to help folks without context understand the intent of the contribution?

@zpcore (Collaborator, Author) commented Jan 31, 2025

> Thanks @zpcore. Can we please add enough detail to PR descriptions to help folks without context understand the intent of the contribution?

The issue in multipod runs is that MegaScale XLA (MXLA) triggers device discovery when we initialize the PJRT runtime with the TPU backend. With the introduction of Pallas kernels, we trigger MXLA an extra time when calling jax.jit(), so device discovery is executed more than once. Every time device discovery runs, every device is assigned an ID; after the second discovery the device IDs no longer match, which causes device communication to hang.

The hacky way to fix the issue is to use an environment variable so that device discovery is not triggered when calling jax.jit. This PR works together with the fix we made internally in the libtpu source code:

// If SKIP_MEGASCALE_PJRT_CLIENT is set (to any value), skip wrapping the
// TPU client in a MegaScale client.
const char* skip_megascale_pjrt_client =
    std::getenv("SKIP_MEGASCALE_PJRT_CLIENT");
bool skip_megascale = false;
if (skip_megascale_pjrt_client != nullptr) {
  skip_megascale = true;
}
// Only multi-slice jobs get a MegaScalePjRtClient, and the environment
// variable can now opt out even then.
if (absl::GetFlag(FLAGS_megascale_num_slices) != 1 && !skip_megascale) {
  client = xla::MegaScalePjRtClient::CreateMegaScalePjRtClient(
      std::move(tpu_client));
  ...
}
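
On the PyTorch/XLA side, the change then only needs to make the variable visible to libtpu while the Pallas path brings up JAX. A minimal sketch, assuming a temporary guard around the JAX calls (the actual change in this PR lives in torch_xla/experimental/custom_kernel.py, and the helper name here is hypothetical):

import os
from contextlib import contextmanager

@contextmanager
def _skip_megascale_pjrt_client():
    # Hypothetical helper: tell the patched libtpu not to create a
    # MegaScalePjRtClient (and so not to re-run device discovery)
    # while JAX initializes its own PJRT client for the Pallas path.
    old = os.environ.get("SKIP_MEGASCALE_PJRT_CLIENT")
    os.environ["SKIP_MEGASCALE_PJRT_CLIENT"] = "1"
    try:
        yield
    finally:
        if old is None:
            os.environ.pop("SKIP_MEGASCALE_PJRT_CLIENT", None)
        else:
            os.environ["SKIP_MEGASCALE_PJRT_CLIENT"] = old

Wrapping the jax.jit call in such a guard means only the initial torch_xla client creation performs MXLA device discovery.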

With this fix, MegaScalePjRtClient will only be created where intended, e.g. in

devices = torch_xla._XLAC._xla_get_devices()

where we call runtime::GetComputationClient() and initialize the client.
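
Putting the pieces together, the intended flow in a multipod job would look roughly like this (a sketch under the assumptions above, not the literal PR code):

import torch_xla

# One-time PJRT initialization: in a multipod run this goes through
# MegaScalePjRtClient and performs the single device discovery that
# assigns every device its ID.
devices = torch_xla._XLAC._xla_get_devices()

# Later Pallas kernels jitted through JAX run with
# SKIP_MEGASCALE_PJRT_CLIENT set (see the sketch above), so JAX gets a
# plain TPU client and the IDs assigned above stay consistent.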
