
Introducing experimental gradient accumulation API #8584

Merged
merged 5 commits into pytorch:master on Jan 22, 2025

Conversation

rpsilva-aws
Collaborator

@rpsilva-aws rpsilva-aws commented Jan 16, 2025

In this PR, we introduce `experimental.gradient_accumulation`, which leverages XLA's `While` op to accumulate gradients.
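
For reference, the baseline in the comparison below ("traditional gradient accumulation") is the usual explicit Python-level loop. This is a minimal sketch with a hypothetical model, data, and step count, not the benchmark code used for the logs:

```python
# Minimal sketch of traditional gradient accumulation on XLA: a Python-level
# loop that sums gradients over `accumulation_steps` micro-batches before
# taking a single optimizer step. Model and data here are hypothetical.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 8

optimizer.zero_grad()
for step in range(accumulation_steps):
    x = torch.randn(16, 128, device=device)
    y = torch.randint(0, 10, (16,), device=device)
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()  # gradients accumulate in param.grad across iterations
optimizer.step()
xm.mark_step()
```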

Training loop with traditional gradient accumulation
===> Preparing data..
Epoch 0 step 8 loss 1.1098170280456543
Epoch 0 step 16 loss 1.1719611883163452
Epoch 0 step 24 loss 3.453134536743164
Epoch 0 step 32 loss 2.518792152404785
Epoch 0 step 40 loss 6.67546272277832
Epoch 0 step 48 loss 4.609560012817383
Epoch 0 step 56 loss 5.953202247619629
Epoch 0 step 64 loss 1.325960636138916
Training loop with XLA's `While` gradient accumulation
===> Preparing data..
Epoch 0 step 8 loss 1.1098170280456543
Epoch 0 step 16 loss 1.1719611883163452
Epoch 0 step 24 loss 3.453134536743164
Epoch 0 step 32 loss 2.518792152404785
Epoch 0 step 40 loss 6.67546272277832
Epoch 0 step 48 loss 4.609560012817383
Epoch 0 step 56 loss 5.953202247619629
Epoch 0 step 64 loss 1.325960636138916
API docstring (from `torch_xla/experimental/gradient_accumulation.py`):

Accumulates gradients over multiple training steps using XLA's `While`
operator to iterate over the leading dimension of the iterable tensors.
The backward computation of the model is implicitly executed following the
`train_step` operations.
Notes:
  The model tracing happens entirely within the loop. It is therefore
  assumed that `train_step` is purposefully encapsulated inside the loop,
  and it is not recommended to have any operation involving the model
  parameters outside of `train_step`.
Args:
  train_step: Training function that takes iterable tensors and carried
        tensors, and returns either a loss tensor or a tuple of (loss,
        *carried_outputs). The iterable tensor inputs to this function should
        disregard the leading dimension.
  iterable_tensors: Input tensors to iterate over. All tensors must have the
        same first dimension size, which determines the number of iterations. The
        underlying loop in the gradient accumulation will iterate through the
        leading dimension of these tensors.
  model: PyTorch model whose parameters will be updated. Note that the entire
          model computation will be traced and generated from within the loop.
  carried_tensors: Optional tensors passed and updated between iterations.
Returns:
  (accumulated_loss, carried_tensor0, carried_tensor1, ...): A tuple including
  the `accumulated_loss` and the same unpacked `carried_tensors` that were
  provided as inputs. In addition, the model parameter gradients, if
  applicable, contain the accumulated gradients.
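
A minimal usage sketch based on the docstring above. The import path, the assumption that `iterable_tensors` is passed as a tuple, and the model/data/optimizer setup are illustrative assumptions, not code from the PR:

```python
# Sketch: using the experimental gradient accumulation API.
# Inputs are shaped so the leading dimension is the number of accumulation
# steps (here 8), each carrying a per-step batch of 16.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(8, 16, 128, device=device)
targets = torch.randint(0, 10, (8, 16), device=device)

def train_step(x, y):
    # `x` and `y` arrive without the leading (accumulation) dimension.
    logits = model(x)
    return nn.functional.cross_entropy(logits, y)

optimizer.zero_grad()
# Per the docstring, accumulated gradients are left on the model parameters;
# the return value includes the accumulated loss (and carried tensors, if any
# were provided).
accumulated_loss = gradient_accumulation(train_step, (inputs, targets), model)
optimizer.step()
xm.mark_step()
```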

@rpsilva-aws rpsilva-aws marked this pull request as ready for review January 16, 2025 19:28
@rpsilva-aws
Collaborator Author

@jeffhataws @tengyifei

@tengyifei
Collaborator

@rpsilva-aws do you plan on merging this into r2.6?

@rpsilva-aws
Collaborator Author

@tengyifei Ideally, yes. It's perfectly fine for the 3-layer MLP, but we're seeing a small difference for Llama runs (relative to a previous local patch set from just before some of the code was cleaned up), so we're quickly identifying what it is.

@tengyifei
Collaborator

Okay, please aim to sort out all critical issues by Jan 21 if you're aiming for 2.6, so that we can review and cherry-pick it by Jan 22. The 2.6 release is quickly drawing in, and I would like a few days to test all the builds.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 3 times, most recently from 08831d6 to 567ccb5 on January 21, 2025 23:25
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 2 times, most recently from 4589eb2 to dfbef15 on January 22, 2025 01:10
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 2 times, most recently from 689dd0e to 1ce443c on January 22, 2025 04:52
@rpsilva-aws
Collaborator Author

rpsilva-aws commented Jan 22, 2025

Tests are succeeding with the default allclose tolerance for gradient accumulation. It's only the gradient checkpointing test (an existing test that I added in the previous PR) that has a relative difference of 4.6153e-4 for the second loss step (all others above 1e-5), likely because we reorganized the train loop (adding some no-op reshape ops). In any case, this is a safe change, since the added scope is specific to the test that is succeeding, but the test-specific refactoring is causing a minor hiccup with the former checkpointing test. I have now separated the tests to disambiguate the failures, and will re-unify them as a follow-up.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 2 times, most recently from b4def6a to a9ab7a5 on January 22, 2025 06:55
Review comments on test/spmd/test_train_spmd_linear_model.py and torch_xla/experimental/gradient_accumulation.py were marked outdated and resolved.
@rpsilva-aws rpsilva-aws requested a review from tengyifei January 22, 2025 17:22
@rpsilva-aws
Collaborator Author

There was a setup issue with CPU's xla_op1. It is not relevant to this PR (well-scoped API), which succeeded in all previous runs. The test is in xla_op3: https://github.com/pytorch/xla/actions/runs/12913551256/job/36016265400?pr=8584#step:13:676.
Pending TPU run.

@tengyifei tengyifei merged commit 36dcba3 into pytorch:master Jan 22, 2025
15 checks passed
@tengyifei
Collaborator

Looks like everything passed after a retry.

@rpsilva-aws rpsilva-aws deleted the rpsilva_grad_acc_v2 branch January 22, 2025 21:57