🚀 Feature
Currently, when loading device data with multi-processing, we leverage MpDeviceLoader, which wraps an existing data loader and takes care of copying the tensors to the device (returning a per_device_loader). Users of torch-xla's MpDeviceLoader face constraints when they require specific tensor shapes for operations such as gradient accumulation. The current implementation only supports placing the batch along a single dimension, forcing users to manually reshape data and handle sharding, which can lead to suboptimal cases (see below), particularly because the data is transferred to the device as is.
Extend ParallelLoader with a flexible accumulation feature that:
- Adds new parameters:
  - accumulation_dim (int, default 1): axis for tensor accumulation
  - accumulation_size (int): number of elements to accumulate
- Maintains compatibility with the existing batch_dim parameter
- Handles accumulation at the ParallelLoader worker level
- Preserves batches_per_execution semantics
For instance, given accumulation_dim=0, accumulation_size=16 and batch_dim=1, with the underlying train loader returning a batch size of 32, we would expect to get [16, 32, seq_dim]. Similarly, accumulation_dim=1, accumulation_size=16 would return [32, 16, seq_dim], since batch_dim currently defaults to 0. If batch_dim and accumulation_dim are equal, we throw an error; the intended stacking behavior is sketched below.
Hence, the customer can specify a train loader with the intended batch size (not counting the accumulation), while explicitly specifying the accumulation dimension and size along which the batches are stacked. A non-breaking requirement is that customers need to specify batch_dim if they use accumulation_size > 0, since it currently defaults to 0.
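For illustration, here is a minimal sketch of the stacking behavior described above, assuming the worker simply stacks accumulation_size consecutive batches along accumulation_dim before the device transfer. The helper below is hypothetical and is not part of torch_xla or the actual ParallelLoader implementation:

```python
import torch

def stack_accumulated(batches, accumulation_dim):
    # Hypothetical helper (illustration only): stack `accumulation_size`
    # consecutive batches from the underlying loader along `accumulation_dim`.
    return torch.stack(batches, dim=accumulation_dim)

# Toy shapes matching the example above: 16 batches of [32, seq_dim].
batches = [torch.randn(32, 4096) for _ in range(16)]  # accumulation_size = 16
assert stack_accumulated(batches, accumulation_dim=0).shape == (16, 32, 4096)  # batch_dim = 1
assert stack_accumulated(batches, accumulation_dim=1).shape == (32, 16, 4096)  # batch_dim = 0
```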
Motivation
The current data loader is tailored to pull in a certain number of batches along the batch dimension (defaulting to 0). However, in some cases training may require a different or more specific shape, such as with gradient accumulation (e.g. [4, 4, 4096] instead of [16, 4096] for batch size = 16, for instance with DP = 4 and 4 gradient accumulation steps). The problem with the current setup is that it constrains the data to be accumulated along a single dimension, and that tensor is then sent to the device as is. The snippet below illustrates the workarounds this currently forces:
```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Assume an arbitrary DP degree for the batch inputs, and that train_loader has
# a batch size of: micro_bs * data_parallelism_degree * gradient_accumulation_steps.
train_device_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

for data, target in train_device_loader:
    # 1) Issue: sharding across the batch size dimension, which is semantically
    # incorrect. Subsequent operations will likely be incompatible, as we expect a
    # different axis; with gradient accumulation, one axis should account for the GA steps.
    xm.mark_sharding(data, mesh, ('data', None))
    xm.mark_sharding(target, mesh, ('data', None))
    ...

    # 2) Issue: sharding across the batch size dimension after reshaping. This forces
    # an all-gather at the end, before reshaping to the original shape, to recollect
    # the resulting data. It is also less trivial for users, and the input is replicated.
    data = data.reshape(gradient_accumulation_steps, -1, *data.shape[1:])
    target = target.reshape(gradient_accumulation_steps, -1, *target.shape[1:])
    xm.mark_sharding(data, mesh, (None, 'data', None))
    xm.mark_sharding(target, mesh, (None, 'data', None))
    ...

    # 3) Issue: not trivial, and while it fixes the replicated input, similar to 2)
    # it forces an all-gather at the end, before reshaping to the original shape,
    # to recollect the resulting data.
    xm.mark_sharding(data, mesh, ('data', None))
    xm.mark_sharding(target, mesh, ('data', None))
    data = data.reshape(gradient_accumulation_steps, -1, *data.shape[1:])
    target = target.reshape(gradient_accumulation_steps, -1, *target.shape[1:])
    xm.mark_sharding(data, mesh, (None, 'data', None))
    xm.mark_sharding(target, mesh, (None, 'data', None))
    ...
```
Pitch
Instead, this proposes doing:
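As a minimal sketch of the intended usage, assuming MpDeviceLoader accepts and forwards the proposed accumulation_dim and accumulation_size arguments (plus the existing batch_dim) to ParallelLoader; the exact signature below reflects the proposal, not an existing API:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Hypothetical usage sketch: assumes MpDeviceLoader accepts and forwards the
# proposed accumulation_dim / accumulation_size (plus the existing batch_dim)
# arguments to ParallelLoader. mesh, train_loader and gradient_accumulation_steps
# are defined as in the motivation example above.
train_device_loader = pl.MpDeviceLoader(
    train_loader,               # yields the intended batch, e.g. [32, seq_dim]
    xm.xla_device(),
    batch_dim=1,                # batch ends up along dim 1
    accumulation_dim=0,         # accumulation steps stacked along dim 0
    accumulation_size=gradient_accumulation_steps,  # e.g. 16
)

for data, target in train_device_loader:
    # data / target already arrive shaped [accumulation_size, batch, seq_dim],
    # e.g. [16, 32, seq_dim], so the batch axis can be sharded directly and no
    # reshape + all-gather round trip is needed.
    xm.mark_sharding(data, mesh, (None, 'data', None))
    xm.mark_sharding(target, mesh, (None, 'data', None))
    ...
```

This keeps the underlying train loader defined with the intended batch size, while the loader worker produces tensors already shaped for the gradient-accumulation loop, avoiding the reshape-and-all-gather workarounds shown above.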
Alternatives
Additional context
TODO - add real-life example