update dataloader
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Dec 18, 2023
1 parent 2ec3215 commit f48dd3d
Showing 2 changed files with 85 additions and 87 deletions.
98 changes: 11 additions & 87 deletions optimum/habana/accelerate/data_loader.py
@@ -1,12 +1,11 @@
import math
from typing import Callable, List, Optional, Union

import torch
from accelerate.data_loader import (
    _PYTORCH_DATALOADER_KWARGS,
    BatchSamplerShard,
    DataLoaderDispatcher,
    DataLoaderShard,
    DataLoaderStateMixin,
    IterableDatasetShard,
    SeedableRandomSampler,
)
@@ -20,64 +19,17 @@
    send_to_device,
    slice_tensors,
)
from accelerate.utils.operations import _gpu_broadcast, is_tensor_information, recursively_apply
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from .state import GaudiAcceleratorState, GaudiPartialState
from .utils import GaudiDistributedType


def initialize_tensors(data_structure):
    """
    Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].
    Returns:
        The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].
    """

    def _initialize_tensor(tensor_info):
        return torch.zeros(*tensor_info.shape, dtype=tensor_info.dtype)

    return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)


def broadcast(tensor, from_process: int = 0):
    """
    Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.
    Returns:
        The same data structure as `tensor` with all tensors broadcasted to the proper device.
    """
    if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
        return _gpu_broadcast(tensor, src=from_process)
    return tensor


def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.
    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.
    Returns:
        The same list containing the objects from process 0.
    """
    if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
        torch.distributed.broadcast_object_list(object_list, src=from_process, device="hpu")
    return object_list
from .state import GaudiAcceleratorState
from .utils.operations import (
    broadcast,
    broadcast_object_list,
    initialize_tensors,
)


class GaudiDataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
class GaudiDataLoaderDispatcher(DataLoaderDispatcher, DataLoader):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
@@ -112,7 +64,7 @@ def __init__(
            # We need to save the shuffling state of the DataPipe
            if isinstance(dataset, ShufflerIterDataPipe):
                shuffle = dataset._shuffle_enabled
        super().__init__(dataset, **kwargs)
        DataLoader.__init__(self, dataset, **kwargs)
        self.split_batches = split_batches
        if shuffle:
            torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -170,9 +122,9 @@ def __iter__(self):
            # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
            # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
            # But, we only iterate through the DataLoader on process 0.
            main_iterator = super().__iter__()
            main_iterator = DataLoader.__iter__(self)
        elif self.state.process_index == 0:
            main_iterator = super().__iter__()
            main_iterator = DataLoader.__iter__(self)
        stop_iteration = False
        self._stop_iteration = False
        first_batch = None
@@ -237,34 +189,6 @@ def __iter__(self):
        self.iteration += 1
        self.end()

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        whole_length = super().__len__()
        if self.split_batches:
            return whole_length
        elif self._drop_last:
            return whole_length // self.state.num_processes
        else:
            return math.ceil(whole_length / self.state.num_processes)

    @property
    def total_batch_size(self):
        return (
            self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
        )

    @property
    def total_dataset_length(self):
        return len(self.dataset)


def gaudi_prepare_data_loader(
    dataloader: DataLoader,
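For context, the change in this file follows a standard subclassing pattern: instead of re-implementing set_epoch, __len__, total_batch_size and total_dataset_length, GaudiDataLoaderDispatcher now derives from Accelerate's DataLoaderDispatcher and only keeps the HPU-specific overrides, calling DataLoader.__init__ and DataLoader.__iter__ directly so the parent dispatcher's CUDA-oriented constructor and iterator are bypassed. A minimal sketch of that pattern is shown below; the class name and the simplified method bodies are illustrative only, while the real __init__ and __iter__ keep their full logic from the diff above.

from accelerate.data_loader import DataLoaderDispatcher
from torch.utils.data import DataLoader


class SketchDataLoaderDispatcher(DataLoaderDispatcher, DataLoader):
    # set_epoch, __len__, total_batch_size and total_dataset_length are now
    # inherited from DataLoaderDispatcher instead of being copied here.

    def __init__(self, dataset, split_batches: bool = False, **kwargs):
        # Call DataLoader.__init__ directly so DataLoaderDispatcher's own
        # constructor is bypassed.
        DataLoader.__init__(self, dataset, **kwargs)
        self.split_batches = split_batches

    def __iter__(self):
        # Likewise iterate through the plain DataLoader machinery; the real
        # subclass wraps this with its process-0 preprocessing and dispatch.
        yield from DataLoader.__iter__(self)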
74 changes: 74 additions & 0 deletions optimum/habana/accelerate/utils/operations.py
@@ -0,0 +1,74 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A set of basic tensor ops compatible with tpu, gpu, and multigpu
"""


import torch
from accelerate.utils.operations import _gpu_broadcast, is_tensor_information, recursively_apply

from ..state import GaudiPartialState
from ..utils import GaudiDistributedType


def initialize_tensors(data_structure):
    """
    Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].
    Returns:
        The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].
    """

    def _initialize_tensor(tensor_info):
        return torch.zeros(*tensor_info.shape, dtype=tensor_info.dtype)

    return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)


def broadcast(tensor, from_process: int = 0):
    """
    Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.
    Returns:
        The same data structure as `tensor` with all tensors broadcasted to the proper device.
    """
    if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
        return _gpu_broadcast(tensor, src=from_process)
    return tensor


def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.
    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.
    Returns:
        The same list containing the objects from process 0.
    """
    if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
        torch.distributed.broadcast_object_list(object_list, src=from_process, device="hpu")
    return object_list
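These three helpers are unchanged apart from their new home, and they short-circuit to plain returns outside of MULTI_HPU/DEEPSPEED runs. The sketch below shows one way they might be used together; the shapes, metadata, and distributed launch are hypothetical, and it assumes a Gaudi environment where GaudiPartialState can initialize.

import torch
from accelerate.utils import TensorInformation

from optimum.habana.accelerate.utils.operations import (
    broadcast,
    broadcast_object_list,
    initialize_tensors,
)

# Non-zero ranks can pre-allocate buffers matching what rank 0 will send
# (hypothetical shape and dtype, for illustration only).
batch = initialize_tensors({"input_ids": TensorInformation(shape=(4, 128), dtype=torch.int64)})

# Broadcast tensors from rank 0; outside MULTI_HPU/DEEPSPEED this returns the input unchanged.
batch = broadcast(batch, from_process=0)

# Small picklable metadata travels the same way and is modified in place.
info = [None]
info = broadcast_object_list(info, from_process=0)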
