Add gather/scatter support for 1D tensors (#229)
This PR adds gather/scatter support for 1D tensors at the Python level, since WholeGraph should support basic indexing operations on both 1D (array) and 2D (matrix) wholememory tensors. Without this PR, the gather/scatter ops do not work with a 1D wholememory tensor, e.g., https://github.com/rapidsai/wholegraph/blob/0efba33835d6e4e104b5d7101a91e0ea55a6ca53/python/pylibwholegraph/pylibwholegraph/torch/tensor.py#L89
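
For context, the indexing semantics being added mirror plain PyTorch row indexing. The sketch below uses only plain `torch` tensors (not the WholeMemory API) to illustrate what gather/scatter should look like for 1D vs. 2D tensors:

```python
import torch

idx = torch.tensor([0, 3, 5])

# 2D (matrix) case: gather returns the selected rows, scatter writes rows back.
mat = torch.arange(12, dtype=torch.float32).reshape(6, 2)
gathered_2d = mat[idx]                # shape [3, 2]
mat[idx] = torch.zeros(3, 2)          # scatter: overwrite the selected rows

# 1D (array) case: gather returns a 1D slice, scatter writes scalars back.
arr = torch.arange(6, dtype=torch.float32)
gathered_1d = arr[idx]                # shape [3], stays 1D
arr[idx] = torch.tensor([7.0, 8.0, 9.0])

print(gathered_2d.shape, gathered_1d.shape)  # torch.Size([3, 2]) torch.Size([3])
```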



To test, run:
```
pytest --cache-clear --import-mode=append tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s
```

**Remaining issue:**

In my local test with a single GPU, the test passes.
In a multi-GPU setup, the gather op works fine, but 1D scatter does not seem to work: it crashes at
https://github.com/rapidsai/wholegraph/blob/2e963b98aa6027c300d60e839010d3dd8ca422eb/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py#L108 with incorrect scatter outputs: `Indices where allclose fails:  tensor([0., 0., 0.,  ..., 0., 0., 0.]) tensor([  1435.,   1439.,   1443.,  ..., 257703., 257707., 257711.])`
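
(The "Indices where allclose fails" line above presumably comes from debug printing around the failing assertion. A minimal standalone pattern along these lines, not necessarily the exact code in the test, prints the mismatching entries of the scattered result next to the reference values:)

```python
import torch

# Hypothetical stand-ins for the scattered local tensor and its reference values.
local_tensor = torch.zeros(8)
local_ref = torch.arange(1435.0, 1435.0 + 8 * 4, 4.0)

mismatch = ~torch.isclose(local_tensor, local_ref)
if mismatch.any():
    print("Indices where allclose fails: ", local_tensor[mismatch], local_ref[mismatch])
```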



@linhu-nv Can you please take a look? Is scatter_op supposed to work with a 1D wholememory tensor?

Authors:
  - Chang Liu (https://github.com/chang-l)

Approvers:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

URL: #229
chang-l authored Nov 22, 2024
1 parent f82c3e7 commit 9a2bb57
Showing 3 changed files with 38 additions and 18 deletions.

`python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py`:

```diff
@@ -24,6 +24,8 @@


 def gen_int_embedding(indice_tensor, embedding_dim, output_type):
+    if embedding_dim == 0:
+        embedding_dim = 1  # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
```
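
The visible part of `gen_int_embedding` builds each reference row by repeating the index value across the embedding dimension, and the new `embedding_dim == 0` branch keeps that reference 2D even when the wholememory tensor itself is 1D. A small standalone illustration of that pattern (assuming the rest of the function only casts and scales this tensor):

```python
import torch

indice_tensor = torch.tensor([2, 5, 7], dtype=torch.int64)
embedding_dim = 0  # 0 signals the 1D-tensor case in the updated test

if embedding_dim == 0:
    embedding_dim = 1  # keep the reference embedding 2D, as the scatter op requires

indice_count = indice_tensor.shape[0]
indice_part = (
    indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
)
print(indice_part)  # tensor([[2], [5], [7]], dtype=torch.int32)
```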

```diff
@@ -54,9 +56,14 @@ def scatter_gather_test_cast(
         "Rank=%d testing scatter gather with embedding_count=%d, embedding_dim=%d, indice_count=%d, dt=%s, mt=%s, ml=%s"
         % (world_rank, embedding_count, embedding_dim, indice_count, dt, mt, ml)
     )
-    wm_embedding = wmb.create_wholememory_matrix(
-        dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
-    )
+    if embedding_dim == 0:
+        wm_embedding = wmb.create_wholememory_array(
+            dt, embedding_count, wm_comm, mt, ml, entry_partition
+        )
+    else:
+        wm_embedding = wmb.create_wholememory_matrix(
+            dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
+        )

     scatter_indice = torch.arange(
         world_rank, embedding_count, world_size, dtype=torch.int64
@@ -91,9 +98,13 @@ def scatter_gather_test_cast(
     local_ref_start = wm_embedding.get_local_entry_start()
     local_ref_count = wm_embedding.get_local_entry_count()
     assert local_start == local_ref_start
-    assert local_tensor_cuda.dim() == 2
+    assert local_tensor_cuda.dim() == 2 if embedding_dim > 0 else 1
     assert local_tensor_cuda.shape[0] == local_ref_count
-    assert local_tensor_cuda.shape[1] == embedding_dim
+    if local_tensor_cuda.dim() == 2:
+        assert local_tensor_cuda.shape[1] == embedding_dim
+    else:
+        # unsqueeze to 2D for comparison
+        local_tensor_cuda = local_tensor_cuda.unsqueeze(1)

     local_tensor = local_tensor_cuda.cpu()
     local_indices = torch.arange(local_ref_start, local_ref_start + local_ref_count, dtype=torch.int64)
```
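
The comparison trick used here (and again below for the gather output) is simply to lift the 1D result to shape `[n, 1]` so the same 2D reference from `gen_int_embedding` can be reused. A quick plain-PyTorch check of that idea, with made-up values:

```python
import torch

local_1d = torch.tensor([1.0, 2.0, 3.0])        # 1D result read back from a wholememory array
ref_2d = torch.tensor([[1.0], [2.0], [3.0]])    # 2D reference built with embedding_dim forced to 1
assert torch.allclose(local_1d.unsqueeze(1), ref_2d)
```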

```diff
@@ -114,6 +125,9 @@ def scatter_gather_test_cast(
     )
     embedding_after_gather = embedding_after_gather_cuda.cpu()
     ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float)
+    if embedding_after_gather.dim() == 1:
+        # unsqueeze to 2D for comparison
+        embedding_after_gather = embedding_after_gather.unsqueeze(1)
     # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % (
     #     gather_indice, embedding_after_gather, ref_embedding_gather))
     assert torch.allclose(embedding_after_gather, ref_embedding_gather)
@@ -134,7 +148,6 @@ def routine_func(world_rank: int, world_size: int):
     wm_comm = wm_comm.wmb_comm

     embedding_count = 1024 * 256 * world_size + 3
-    embedding_dim = 256
     indice_count = 100001
     dt = wmb.WholeMemoryDataType.DtFloat
     entry_partition = random_partition(embedding_count, world_size)
@@ -150,11 +163,12 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            if wm_comm.support_type_location(mt, ml):
-                scatter_gather_test_cast(
-                    wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
-                )
-            # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
+            for embedding_dim in [0, 256]:  # 0 is for 1D tensor
+                if wm_comm.support_type_location(mt, ml):
+                    scatter_gather_test_cast(
+                        wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
+                    )
+                # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
     wmb.finalize()
```

`python/pylibwholegraph/pylibwholegraph/torch/tensor.py` (8 additions, 4 deletions):

```diff
@@ -63,7 +63,7 @@ def gather(self,
                *,
                force_dtype: Union[torch.dtype, None] = None):
         assert indice.dim() == 1
-        embedding_dim = self.shape[1]
+        embedding_dim = self.shape[1] if self.dim() == 2 else 1
         embedding_count = indice.shape[0]
         current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
         output_dtype = (
@@ -80,15 +80,19 @@ def gather(self,
             wrap_torch_tensor(output_tensor),
             get_wholegraph_env_fns(),
             get_stream())
-        return output_tensor
+        return output_tensor.view(-1) if self.dim() == 1 else output_tensor

     def scatter(self,
                 input_tensor: torch.Tensor,
                 indice: torch.Tensor):
         assert indice.dim() == 1
-        assert input_tensor.dim() == 2
+        assert input_tensor.dim() == self.dim()
         assert indice.shape[0] == input_tensor.shape[0]
-        assert input_tensor.shape[1] == self.shape[1]
+        if self.dim() == 2:
+            assert input_tensor.shape[1] == self.shape[1]
+        else:
+            # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op
+            input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(wrap_torch_tensor(input_tensor),
                                    wrap_torch_tensor(indice),
                                    self.wmb_tensor,
```

Third changed file (path not shown in this view):

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -39,8 +39,10 @@ def wholememory_gather_forward_functor(
     assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64
     if torch_output_dtype is None:
         torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype)
+
+    embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1
     output_tensor = torch.empty(
-        [indices_tensor.shape[0], wholememory_tensor.shape[1]],
+        [indices_tensor.shape[0], embedding_dim],
         device="cuda",
         dtype=torch_output_dtype,
         requires_grad=requires_grad,
@@ -52,7 +54,7 @@
         get_wholegraph_env_fns(),
         get_stream(),
     )
-    return output_tensor
+    return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor


 def wholememory_scatter_functor(
```
