ParallelLoader with Flexible Tensor Accumulation #8616

Open
rpsilva-aws opened this issue Jan 23, 2025 · 2 comments

Comments

@rpsilva-aws (Collaborator) commented Jan 23, 2025

🚀 Feature

Currently, when loading device data with multi-processing, we leverage MpDeviceLoader, which wraps an existing data loader and takes care of copying the tensors to the device (returning per_device_loader). Users of torch-xla's MpDeviceLoader face constraints when requiring specific tensor shapes for operations like gradient accumulation. The current implementation only supports placing the batch along a single dimension, forcing users to manually reshape data and handle sharding, which can lead to suboptimal patterns (see below), particularly because the data is transferred to the device as-is.

Extend ParallelLoader with a flexible accumulation feature that:

  1. Adds new parameters:
    • accumulation_dim (int) = 1: Axis for tensor accumulation
    • accumulation_size (int): Number of elements to accumulate
  2. Maintains compatibility with existing batch_dim parameter
  3. Handles accumulation at the ParallelLoader worker level
  4. Preserves batches_per_execution semantics

For instance, given accumulation_dim=0, accumulation_size=16 and batch_dim=1, with the underlying train loader returning batches of size 32, we would expect to get [16, 32, seq_dim]. Similarly, accumulation_dim=1, accumulation_size=16 would return [32, 16, seq_dim], since batch_dim currently defaults to 0. If batch_dim and accumulation_dim are equal, we throw an error.

Hence, the customer can specify a train loader with the intended batch size (not counting the accumulation), and explicitly specify the accumulation dimension and size along which the batches are stacked. A non-breaking requirement is that customers need to specify batch_dim if they use accumulation_size > 0, since it currently defaults to 0.
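
A minimal sketch of the intended shape semantics (accumulation_dim and accumulation_size are proposed parameters that do not exist in torch_xla today; the stacking below is only illustrative):

import torch

# The underlying loader yields batches of shape [32, seq_len] per iteration.
seq_len = 4096
raw_batches = [torch.randn(32, seq_len) for _ in range(16)]

# accumulation_dim=0, accumulation_size=16, batch_dim=1:
# 16 consecutive batches stacked along dim 0 -> [16, 32, seq_len].
assert torch.stack(raw_batches, dim=0).shape == (16, 32, seq_len)

# accumulation_dim=1, accumulation_size=16, batch_dim=0 (current default):
# the same batches stacked along dim 1 -> [32, 16, seq_len].
assert torch.stack(raw_batches, dim=1).shape == (32, 16, seq_len)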

Motivation

The current data loader is tailored to pull in a certain number of batches along the batch dimension (defaulted to 0). However, in some cases, the training may require a different or specific shape, such as with gradient accumulation (e.g. [4, 4, 4096] instead of [16, 4096] with batch size = 16, for instance DP = 4 and 4 gradient accumulation steps). The problem with the current setup is that it constrains the data to be accumulated along a single dimension, and that tensor is then sent to the device as-is.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Assuming that we have an arbitrary DP degree for the batch inputs.
# Assume train_dataloader has a batch size of `micro_bs * data_parallelism_degree * gradient_accumulation_steps`.
train_device_loader = pl.MpDeviceLoader(train_dataloader, xm.xla_device())

for data, target in train_device_loader:

  # 1) Issue: Sharding across the batch size dimension, which is semantically incorrect.
  # Subsequent operations will likely be incompatible, as we expect a different axis. In
  # the case of gradient accumulation, we expect one of the axes to account for the GA steps.
  xm.mark_sharding(data, mesh, ('data', None))
  xm.mark_sharding(target, mesh, ('data', None))
  ...

  # 2) Issue: Sharding across the batch size dimension after reshaping. This forces an
  # all-gather at the end, before reshaping back to the original shape, to recollect the
  # resulting data. It's also less trivial for users. In addition, the input itself will
  # be replicated.
  data = data.reshape(gradient_accumulation_steps, -1, *data.shape[1:])
  target = target.reshape(gradient_accumulation_steps, -1, *target.shape[1:])
  xm.mark_sharding(data, mesh, (None, 'data', None))
  xm.mark_sharding(target, mesh, (None, 'data', None))
  ...

  # 3) Issue: Not trivial, and while it fixes the replicated input, similar to 2), this
  # forces an all-gather at the end, before reshaping back to the original shape, to
  # recollect the resulting data.
  xm.mark_sharding(data, mesh, ('data', None))
  xm.mark_sharding(target, mesh, ('data', None))
  data = data.reshape(gradient_accumulation_steps, -1, *data.shape[1:])
  target = target.reshape(gradient_accumulation_steps, -1, *target.shape[1:])
  xm.mark_sharding(data, mesh, (None, 'data', None))
  xm.mark_sharding(target, mesh, (None, 'data', None))
  ...

Pitch

Instead, this proposal enables:

# Assume train_dataloader has a batch size of `micro_bs * data_parallelism_degree`.
train_device_loader = pl.MpDeviceLoader(train_dataloader, xm.xla_device(), batch_dim=1, accumulation_dim=0, accumulation_size=16)
for data, target in train_device_loader:
  xm.mark_sharding(data, mesh, (None, 'data', None))
  xm.mark_sharding(target, mesh, (None, 'data', None))
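
A minimal sketch of how the accumulation could be handled at the ParallelLoader worker level, assuming a hypothetical wrapper (_AccumulatingIterator below is illustrative and not an existing torch_xla class):

import torch

class _AccumulatingIterator:
  """Illustrative only: stacks `accumulation_size` consecutive batches from an
  underlying loader along `accumulation_dim`, before the device transfer."""

  def __init__(self, loader, accumulation_dim=0, accumulation_size=1):
    self._loader = loader
    self._accumulation_dim = accumulation_dim
    self._accumulation_size = accumulation_size

  def __iter__(self):
    buffer = []
    for batch in self._loader:
      buffer.append(batch)
      if len(buffer) == self._accumulation_size:
        # Group the corresponding tensors of the buffered batches and stack them
        # along the accumulation axis, e.g. 16 x [32, seq] -> [16, 32, seq].
        yield tuple(
            torch.stack(tensors, dim=self._accumulation_dim)
            for tensors in zip(*buffer))
        buffer = []

With the stacking done before the host-to-device transfer, the per-device loader would already yield tensors in the accumulated shape, so the sharding annotation in the loop above applies directly without a reshape or a subsequent all-gather.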

Alternatives

Additional context

TODO - add real-life example

@rpsilva-aws (Collaborator, Author) commented:
There's some added complexity in how we can stack an arbitrary dataset structure, but that is implementation-specific.
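
For nested batch structures (e.g. dicts or tuples of tensors), one possible approach, sketched here using PyTorch's internal torch.utils._pytree helpers (an assumption, not a settled design), is to stack leaf tensors recursively:

import torch
from torch.utils import _pytree as pytree  # internal PyTorch utility

def stack_batches(batches, dim=0):
  # Stack a list of structurally identical batches (tensors, or nested
  # tuples/lists/dicts of tensors) along `dim`, leaf by leaf.
  leaves_per_batch = []
  spec = None
  for batch in batches:
    leaves, spec = pytree.tree_flatten(batch)
    leaves_per_batch.append(leaves)
  stacked = [torch.stack(group, dim=dim) for group in zip(*leaves_per_batch)]
  return pytree.tree_unflatten(stacked, spec)

# Example: a dict-style batch, e.g. from a tokenized text dataset.
batch = {'input_ids': torch.randint(0, 100, (32, 128)),
         'labels': torch.randint(0, 100, (32,))}
stacked = stack_batches([batch] * 16, dim=0)
assert stacked['input_ids'].shape == (16, 32, 128)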

@rpsilva-aws (Collaborator, Author) commented:
cc: @tengyifei, @ManfeiBai, do you folks know who can evaluate any risk in the proposal? I'll work on a draft change if none.
