fix DataLoaderDispatcher issue in Gaudi #600

Merged 2 commits on Dec 18, 2023

Changes from all commits
179 changes: 177 additions & 2 deletions optimum/habana/accelerate/data_loader.py
@@ -9,10 +9,185 @@
IterableDatasetShard,
SeedableRandomSampler,
)
-from accelerate.utils import RNGType
from accelerate.state import GradientState
from accelerate.utils import (
RNGType,
concatenate,
find_batch_size,
get_data_structure,
is_torch_version,
send_to_device,
slice_tensors,
)
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from .state import GaudiAcceleratorState
from .utils.operations import (
broadcast,
broadcast_object_list,
initialize_tensors,
)


class GaudiDataLoaderDispatcher(DataLoaderDispatcher, DataLoader):
"""
    Subclass of a PyTorch `DataLoader` that iterates and preprocesses batches on process 0 only, then dispatches to
    each process its part of the batch.

Args:
split_batches (`bool`, *optional*, defaults to `False`):
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
            `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
            the same as the initial `dataloader` if this option is set to `True`, and the batch size of the initial
            `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
            size of the `dataloader` be a round multiple of the number of processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.

**Available attributes:**

- **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
number of processes

- **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
"""

def __init__(
self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, slice_fn=None, **kwargs
):
shuffle = False
if is_torch_version(">=", "1.11.0"):
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
DataLoader.__init__(self, dataset, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)

self.gradient_state = GradientState()
self.state = GaudiAcceleratorState()
self._drop_last = _drop_last
self.skip_batches = skip_batches

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
batches.append(next(iterator))
batch = concatenate(batches, dim=0)
# In both cases, we need to get the structure of the batch that we will broadcast on other
# processes to initialize the tensors with the right shape.
# data_structure, stop_iteration
batch_info = [get_data_structure(batch), False]
except StopIteration:
batch_info = [None, True]
else:
batch_info = [None, self._stop_iteration]
# This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
broadcast_object_list(batch_info)
self._stop_iteration = batch_info[1]
if self._stop_iteration:
# If drop_last is False and split_batches is False, we may have a remainder to take care of.
if not self.split_batches and not self._drop_last:
if self.state.process_index == 0 and len(batches) > 0:
batch = concatenate(batches, dim=0)
batch_info = [get_data_structure(batch), False]
else:
batch_info = [None, True]
broadcast_object_list(batch_info)
return batch, batch_info

def __iter__(self):
self.begin()
self.set_epoch(self.iteration)
main_iterator = None
if is_torch_version(">=", "2.0.1"):
            # NOTE: PyTorch's DataLoader adds forward compatibility for DataPipes, which broadcasts
            # a shared seed to all distributed processes. Thus, we need to create the iterator on all
            # distributed processes, but we only iterate through the DataLoader on process 0.
main_iterator = DataLoader.__iter__(self)
elif self.state.process_index == 0:
main_iterator = DataLoader.__iter__(self)
stop_iteration = False
self._stop_iteration = False
first_batch = None
next_batch, next_batch_info = self._fetch_batches(main_iterator)
batch_index = 0
while not stop_iteration:
batch, batch_info = next_batch, next_batch_info

if self.state.process_index != 0:
# Initialize tensors on other processes than process 0.
batch = initialize_tensors(batch_info[0])
batch = send_to_device(batch, self.state.device)
# Broadcast the batch before splitting it.
batch = broadcast(batch, from_process=0)

if not self._drop_last and first_batch is None:
# We keep at least num processes elements of the first batch to be able to complete the last batch
first_batch = self.slice_fn(
batch,
slice(0, self.state.num_processes),
process_index=self.state.process_index,
num_processes=self.state.num_processes,
)

if batch is None:
                raise ValueError(
                    f"Batch does not contain any data (`{batch}`). The end of the iterable data was reached before the expected stop iteration."
                )

observed_batch_size = find_batch_size(batch)
batch_size = observed_batch_size // self.state.num_processes

stop_iteration = self._stop_iteration
if not stop_iteration:
                # We may still be at the end of the dataloader without knowing it yet: this happens when there is
                # nothing left because the number of batches was a round multiple of the number of processes.
next_batch, next_batch_info = self._fetch_batches(main_iterator)
# next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
if self._stop_iteration and next_batch_info[0] is None:
stop_iteration = True

if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
# If the last batch is not complete, let's add the first batch to it.
batch = concatenate([batch, first_batch], dim=0)
                # The batch size computed above is off by one in this case, so we fix it.
batch_size += 1

data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
batch = self.slice_fn(
batch,
data_slice,
process_index=self.state.process_index,
num_processes=self.state.num_processes,
)

if stop_iteration:
self.end_of_dataloader = True
self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
yield batch
batch_index += 1
self.iteration += 1
self.end()
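
To make the remainder handling in `__iter__` concrete, here is a minimal arithmetic sketch (editorial; not part of the diff, and the numbers are hypothetical): when `_drop_last` is `False` and the final gathered batch is not divisible by the number of processes, the saved `first_batch` tops it up so every process can take an equal slice.

    num_processes = 8
    observed_batch_size = 30                           # final gathered batch: 30 % 8 != 0
    batch_size = observed_batch_size // num_processes  # 3 samples per process at first
    # `first_batch` (the first num_processes samples saved earlier) is concatenated,
    # so 30 + 8 = 38 samples are available; each process now takes batch_size + 1 = 4.
    batch_size += 1
    assert num_processes * batch_size <= observed_batch_size + num_processes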


def gaudi_prepare_data_loader(
@@ -204,7 +379,7 @@ def gaudi_prepare_data_loader(
        dataloader.batch_sampler.sampler = sampler
    if dispatch_batches:
        kwargs.pop("generator")
-       dataloader = DataLoaderDispatcher(
+       dataloader = GaudiDataLoaderDispatcher(
            new_dataset,
            split_batches=split_batches,
            batch_sampler=new_batch_sampler,
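
For reference, a minimal sketch (editorial, not part of the PR; the helper name is ours) of the `split_batches` semantics described in the class docstring: with `split_batches=True` one batch is split across processes, otherwise `num_processes` batches are concatenated and then split, multiplying the observed total batch size.

    def total_batch_size(loader_batch_size: int, num_processes: int, split_batches: bool) -> int:
        # True: a single batch is split, so the observed total stays unchanged.
        # False: num_processes batches are concatenated, then split.
        return loader_batch_size if split_batches else loader_batch_size * num_processes

    assert total_batch_size(32, num_processes=8, split_batches=True) == 32
    assert total_batch_size(32, num_processes=8, split_batches=False) == 256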
74 changes: 74 additions & 0 deletions optimum/habana/accelerate/utils/operations.py
@@ -0,0 +1,74 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A set of basic tensor ops compatible with tpu, gpu, and multigpu
"""


import torch
from accelerate.utils.operations import _gpu_broadcast, is_tensor_information, recursively_apply

from ..state import GaudiPartialState
from ..utils import GaudiDistributedType


def initialize_tensors(data_structure):
"""
Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].

Returns:
        The same data structure as `data_structure` with tensors instead of [`~utils.TensorInformation`].
"""

def _initialize_tensor(tensor_info):
return torch.zeros(*tensor_info.shape, dtype=tensor_info.dtype)

return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)


def broadcast(tensor, from_process: int = 0):
"""
    Recursively broadcasts the tensors in a nested list/tuple/dictionary of tensors to all devices.

Args:
tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.

Returns:
The same data structure as `tensor` with all tensors broadcasted to the proper device.
"""
if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
return _gpu_broadcast(tensor, src=from_process)
return tensor


def broadcast_object_list(object_list, from_process: int = 0):
"""
    Broadcasts a list of picklable objects from one process to the others.

Args:
object_list (list of picklable objects):
The list of objects to broadcast. This list will be modified inplace.
from_process (`int`, *optional*, defaults to 0):
The process from which to send the data.

Returns:
        The same list containing the objects from process `from_process`.
"""
if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
torch.distributed.broadcast_object_list(object_list, src=from_process, device="hpu")
return object_list
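
As a single-process illustration of the dispatch protocol these helpers serve (editorial sketch; assumes no distributed process group is initialized, so no actual broadcast happens; `get_data_structure` and `initialize_tensors` are the upstream Accelerate utilities): process 0 broadcasts the batch structure plus a stop flag, and the other processes allocate matching zero tensors to receive the payload.

    import torch
    from accelerate.utils import get_data_structure
    from accelerate.utils.operations import initialize_tensors  # upstream counterpart of the Gaudi helpers

    # What process 0 would put in batch_info before broadcast_object_list.
    batch = {"input_ids": torch.ones(8, 4, dtype=torch.int64)}
    batch_info = [get_data_structure(batch), False]  # [data structure, stop_iteration]

    # What every other process would allocate before receiving the broadcast payload.
    placeholder = initialize_tensors(batch_info[0])
    assert placeholder["input_ids"].shape == (8, 4)
    assert placeholder["input_ids"].dtype == torch.int64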