torch.cat and torch.stack batching rules (#43798)
Summary:
Pull Request resolved: pytorch/pytorch#43798

These are relatively straightforward.
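A rough sketch of the behavior these rules enable (assuming the prototype `torch.vmap` entry point from this point in time; since `vmap` can't take a list of tensors directly, `cat`/`stack` are wrapped in functions of positional tensor arguments, mirroring the new tests):

```python
import torch

# vmap can't accept a list of tensors as an argument, so wrap cat/stack
# in functions that take tensors positionally (same trick as the tests).
def cat0(x, y):
    return torch.cat((x, y), dim=0)

def stack0(x, y):
    return torch.stack((x, y), dim=0)

B = 5
x, y = torch.rand(B, 2), torch.rand(B, 3)
assert torch.vmap(cat0)(x, y).shape == (B, 5)       # per-example cat

x, y = torch.rand(B, 3), torch.rand(B, 3)
assert torch.vmap(stack0)(x, y).shape == (B, 2, 3)  # per-example stack
```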

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23405000

Pulled By: zou3519

fbshipit-source-id: 65c78da3dee43652636bdb0a65b636fca69e765d
zou3519 authored and facebook-github-bot committed Sep 1, 2020
1 parent dbc4218 commit 9b98bce
Showing 2 changed files with 68 additions and 1 deletion.
28 changes: 28 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -432,6 +432,30 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(TensorList tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
  return physical_views[0].newLogicalFromPhysical(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
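  // Illustration (assumes one level of vmap over 1-D logical tensors): each
  // physical tensor then has shape [B, N] and numBatchDims() == 1. For
  // dim=-1, maybe_wrap_dim(-1, /*logical dim + 1=*/2) gives 1, so the
  // physical dim below is 2 and the new dimension lands after N, matching
  // torch.stack(..., dim=-1) on each per-example slice.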
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].newLogicalFromPhysical(result);
}

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
@@ -586,6 +610,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);

  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);
}

} // namespace at
41 changes: 40 additions & 1 deletion test/test_vmap.py
@@ -128,7 +128,7 @@ def test_unsupported_op_err_msg(self):

        # The fallback doesn't support TensorList
        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
-            vmap(lambda t: torch.stack([t]))(tensor)
+            vmap(lambda t: torch.atleast_1d([t]))(tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
@@ -906,6 +906,25 @@ def test_bmm(self):
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))

    def test_cat(self):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.cat(tensors, dim=dim)
            return op

        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
        test(vmap(get_op(0), in_dims=(0, None)),
             (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0))
        test(vmap(get_op(0), in_dims=(0, 0)),
             (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))

    def test_chunk(self):
        test = self._vmap_view_test
        op = torch.chunk
@@ -1074,6 +1093,26 @@ def test_select(self):
        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)

    def test_stack(self):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.stack(tensors, dim=dim)
            return op

        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(vmap(get_op(0), in_dims=(0, None)),
             (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0))
        test(vmap(get_op(0), in_dims=(0, 0)),
             (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0))


    def test_slice(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
