diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc01689e..42cfb22d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275))
 - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
 - Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
+- Fixed `grouped_matmul` when tensors are not contiguous ([#290](https://github.com/pyg-team/pyg-lib/pull/290))
 ### Removed

 ## [0.3.0] - 2023-10-11
diff --git a/pyg_lib/csrc/ops/cuda/matmul_kernel.cu b/pyg_lib/csrc/ops/cuda/matmul_kernel.cu
index 9909425a..6c98bef6 100644
--- a/pyg_lib/csrc/ops/cuda/matmul_kernel.cu
+++ b/pyg_lib/csrc/ops/cuda/matmul_kernel.cu
@@ -45,9 +45,9 @@ void run_grouped_gemm(const at::TensorList input,

   // Set arguments into gemm_args from input args
   for (size_t i = 0; i < num_matrices; ++i) {
-    auto new_in = input[i].contiguous();
-    auto new_other = other[i].contiguous();
-    auto new_out = out[i].contiguous();
+    auto new_in = input[i];
+    auto new_other = other[i];
+    auto new_out = out[i];
     auto m = new_in.size(0), k = new_other.size((int)(segment)),
          n = new_out.size(1);

@@ -288,9 +288,14 @@ void grouped_matmul_out_kernel(const at::TensorList input,
 std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
                                               const at::TensorList other) {
   std::vector<at::Tensor> out(input.size());
-  for (size_t i = 0; i < input.size(); ++i)
+  std::vector<at::Tensor> input_contiguous(input.size());
+  std::vector<at::Tensor> other_contiguous(other.size());
+  for (size_t i = 0; i < input.size(); ++i) {
+    input_contiguous[i] = input[i].contiguous();
+    other_contiguous[i] = other[i].contiguous();
     out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)});
-  grouped_matmul_out_kernel(input, other, out, false);
+  }
+  grouped_matmul_out_kernel(input_contiguous, other_contiguous, out, false);

   return out;
 }
diff --git a/test/ops/test_matmul.py b/test/ops/test_matmul.py
index b9001d3f..42a0ab91 100644
--- a/test/ops/test_matmul.py
+++ b/test/ops/test_matmul.py
@@ -39,7 +39,8 @@ def test_segment_matmul_autograd(dtype, device):

 @withCUDA
 @pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
-def test_grouped_matmul_autograd(dtype, device):
+@pytest.mark.parametrize('transposed', [True, False])
+def test_grouped_matmul_autograd(dtype, transposed, device):
     if device.type == 'cuda' and dtype == torch.bfloat16:
         pytest.skip('CUDA does not support bfloat16')

@@ -48,11 +49,19 @@ def test_grouped_matmul_autograd(dtype, device):
         torch.randn(6, 9, device=device, requires_grad=True),
         torch.randn(3, 32, device=device, requires_grad=True),
     ]
-    others = [
-        torch.randn(16, 48, device=device, requires_grad=True),
-        torch.randn(9, 42, device=device, requires_grad=True),
-        torch.randn(32, 64, device=device, requires_grad=True),
-    ]
+    if transposed:
+        others_origin = [
+            torch.randn(48, 16, device=device, requires_grad=True),
+            torch.randn(42, 9, device=device, requires_grad=True),
+            torch.randn(64, 32, device=device, requires_grad=True),
+        ]
+        others = [other.t() for other in others_origin]
+    else:
+        others = [
+            torch.randn(16, 48, device=device, requires_grad=True),
+            torch.randn(9, 42, device=device, requires_grad=True),
+            torch.randn(32, 64, device=device, requires_grad=True),
+        ]
     biases = [
         torch.randn(48, device=device, requires_grad=True),
         torch.randn(42, device=device, requires_grad=True),
@@ -70,4 +79,7 @@ def test_grouped_matmul_autograd(dtype, device):

     sum([out.sum() for out in outs]).backward()
     for i in range(len(outs)):
-        assert others[i].grad.size() == others[i].size()
+        if transposed:
+            assert others_origin[i].grad.size() == others_origin[i].size()
+        else:
+            assert others[i].grad.size() == others[i].size()
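
Note: the snippet below is an illustrative usage sketch, not part of the patch. It exercises `pyg_lib.ops.grouped_matmul` with non-contiguous (transposed) weight matrices, mirroring the new `transposed=True` test path; the shapes are arbitrary and a CUDA build of pyg-lib is assumed. The fix itself moves the `.contiguous()` calls out of `run_grouped_gemm` and into `grouped_matmul_kernel`, where the copies are held in `input_contiguous`/`other_contiguous` for the duration of the launch.

    import torch
    import pyg_lib

    device = torch.device('cuda')
    inputs = [
        torch.randn(5, 16, device=device),
        torch.randn(3, 32, device=device),
    ]
    # .t() returns non-contiguous views; with this patch they are made
    # contiguous once in grouped_matmul_kernel before dispatch.
    others = [
        torch.randn(48, 16, device=device).t(),  # logical shape (16, 48)
        torch.randn(64, 32, device=device).t(),  # logical shape (32, 64)
    ]
    outs = pyg_lib.ops.grouped_matmul(inputs, others)
    assert outs[0].size() == (5, 48)
    assert outs[1].size() == (3, 64)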