
[Grouped Matmul] Fix PyTorch memory leak when tensors are not contiguous #290

Merged (5 commits, Jan 7, 2024)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275))
 - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
 - Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
+- Fixed `grouped_matmul` when tensors are not contiguous ([#290](https://github.com/pyg-team/pyg-lib/pull/290))
 ### Removed

 ## [0.3.0] - 2023-10-11
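For context, a minimal sketch of the call pattern this changelog entry refers to (assuming the public `pyg_lib.ops.grouped_matmul` API): transposing a weight matrix yields a non-contiguous view, which is exactly the kind of input that used to trigger the leak on CUDA.

```python
import torch
import pyg_lib

# Hypothetical shapes for illustration; .t() returns a strided view,
# so the weight operands below are not contiguous in memory.
inputs = [torch.randn(8, 16, device='cuda'), torch.randn(6, 9, device='cuda')]
others = [torch.randn(48, 16, device='cuda').t(),  # shape (16, 48), non-contiguous
          torch.randn(42, 9, device='cuda').t()]   # shape (9, 42), non-contiguous
assert not others[0].is_contiguous()

# Before #290, repeated calls like this could leak memory; with the fix,
# contiguous copies are made once per call inside the kernel wrapper.
outs = pyg_lib.ops.grouped_matmul(inputs, others)
```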
15 changes: 10 additions & 5 deletions pyg_lib/csrc/ops/cuda/matmul_kernel.cu
@@ -45,9 +45,9 @@ void run_grouped_gemm(const at::TensorList input,

   // Set arguments into gemm_args from input args
   for (size_t i = 0; i < num_matrices; ++i) {
-    auto new_in = input[i].contiguous();
-    auto new_other = other[i].contiguous();
-    auto new_out = out[i].contiguous();
+    auto new_in = input[i];
+    auto new_other = other[i];
+    auto new_out = out[i];
     auto m = new_in.size(0), k = new_other.size((int)(segment)),
          n = new_out.size(1);

@@ -288,9 +288,14 @@ void grouped_matmul_out_kernel(const at::TensorList input,
 std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
                                               const at::TensorList other) {
   std::vector<at::Tensor> out(input.size());
-  for (size_t i = 0; i < input.size(); ++i)
+  std::vector<at::Tensor> input_contiguous(input.size());
+  std::vector<at::Tensor> other_contiguous(other.size());
+  for (size_t i = 0; i < input.size(); ++i) {
+    input_contiguous[i] = input[i].contiguous();
+    other_contiguous[i] = other[i].contiguous();
     out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)});
-  grouped_matmul_out_kernel(input, other, out, false);
+  }
+  grouped_matmul_out_kernel(input_contiguous, other_contiguous, out, false);

   return out;
 }
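The gist of the change above: the per-matrix `.contiguous()` calls are removed from `run_grouped_gemm` and performed once in `grouped_matmul_kernel`, where the copies are kept alive in `input_contiguous` / `other_contiguous` and handed to `grouped_matmul_out_kernel`. A rough Python analogue of the restructured flow (a reference sketch only, not the CUDA implementation):

```python
from typing import List
import torch

def grouped_matmul_reference(inputs: List[torch.Tensor],
                             others: List[torch.Tensor]) -> List[torch.Tensor]:
    # Make contiguous copies once, up front, and keep them alive for the whole
    # computation (mirrors input_contiguous / other_contiguous above).
    inputs_c = [x.contiguous() for x in inputs]
    others_c = [w.contiguous() for w in others]
    outs = [x.new_empty(x.size(0), w.size(-1)) for x, w in zip(inputs, others)]
    for out, a, b in zip(outs, inputs_c, others_c):
        torch.mm(a, b, out=out)  # per-group GEMM into the preallocated output
    return outs
```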
26 changes: 19 additions & 7 deletions test/ops/test_matmul.py
@@ -39,7 +39,8 @@ def test_segment_matmul_autograd(dtype, device):

 @withCUDA
 @pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
-def test_grouped_matmul_autograd(dtype, device):
+@pytest.mark.parametrize('transposed', [True, False])
+def test_grouped_matmul_autograd(dtype, transposed, device):
     if device.type == 'cuda' and dtype == torch.bfloat16:
         pytest.skip('CUDA does not support bfloat16')

@@ -48,11 +49,19 @@ def test_grouped_matmul_autograd(dtype, device):
         torch.randn(6, 9, device=device, requires_grad=True),
         torch.randn(3, 32, device=device, requires_grad=True),
     ]
-    others = [
-        torch.randn(16, 48, device=device, requires_grad=True),
-        torch.randn(9, 42, device=device, requires_grad=True),
-        torch.randn(32, 64, device=device, requires_grad=True),
-    ]
+    if transposed:
+        others_origin = [
+            torch.randn(48, 16, device=device, requires_grad=True),
+            torch.randn(42, 9, device=device, requires_grad=True),
+            torch.randn(64, 32, device=device, requires_grad=True),
+        ]
+        others = [other.t() for other in others_origin]
+    else:
+        others = [
+            torch.randn(16, 48, device=device, requires_grad=True),
+            torch.randn(9, 42, device=device, requires_grad=True),
+            torch.randn(32, 64, device=device, requires_grad=True),
+        ]

     biases = [
         torch.randn(48, device=device, requires_grad=True),
@@ -70,4 +79,7 @@ def test_grouped_matmul_autograd(dtype, device):

     sum([out.sum() for out in outs]).backward()
     for i in range(len(outs)):
-        assert others[i].grad.size() == others[i].size()
+        if transposed:
+            assert others_origin[i].grad.size() == others_origin[i].size()
+        else:
+            assert others[i].grad.size() == others[i].size()
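The `transposed=True` case is what exercises the non-contiguous path: `t()` returns a view with swapped strides, and autograd accumulates the gradient on the original leaf tensors, which is why the assertions switch to `others_origin`. A small plain-PyTorch illustration of both facts:

```python
import torch

w = torch.randn(48, 16, requires_grad=True)
wt = w.t()                      # view with swapped strides
assert not wt.is_contiguous()   # the case grouped_matmul now has to handle

out = torch.randn(8, 16) @ wt   # (8, 16) @ (16, 48) -> (8, 48)
out.sum().backward()
assert w.grad.shape == w.shape  # gradient lands on the leaf, as with others_origin
```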