From 76ab1ab66560213701943ecde368aedcd5de08e5 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 8 Oct 2024 13:06:28 -0400 Subject: [PATCH] Fix autograd.Function + NJT when an output grad is None (#136875) For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws: ``` RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers ``` This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875 Approved by: https://github.com/soulitzer --- test/test_nestedtensor.py | 30 +++++++++++++++++++++++++ torch/csrc/autograd/python_function.cpp | 12 +++++++++- torch/csrc/autograd/variable_info.cpp | 8 +++++-- torch/csrc/autograd/variable_info.h | 4 +++- 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 1b4a84a484dce..96b29245c6d3f 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -7055,6 +7055,36 @@ def test_noncontiguous_to(self, device, dtype, contiguity): if nt._lengths is not None: self.assertEqual(nt3._lengths.device, other_device) + @dtypes(torch.float32) + def test_autograd_function_with_None_grad(self, device, dtype): + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + ctx.save_for_backward(inp) + out1 = inp + 1 + out2 = inp * 2 + return out1, out2 + + @staticmethod + def backward(ctx, grad_out1, grad_out2): + (inp,) = ctx.saved_tensors + return grad_out1 + grad_out2 + + f = MyFunction.apply + nt = random_nt_from_dims( + [5, None, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # Only use one of the autograd.Function outputs downstream so that the grad + # for the other output is None. We're testing that the engine can allocate + # correctly-shaped (NJT) zeros for the grad of the other output in this case. + (out1, _) = f(nt) + out1.backward(torch.ones_like(out1)) + @dtypes(torch.float64, torch.float32, torch.half) def test_jagged_padded_dense_conversion_kernels(self, device, dtype): values = torch.randn(10, 5, device=device, dtype=dtype) diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 415de56a49095..120f1934b7356 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -733,8 +733,18 @@ static void _wrap_outputs( PyTuple_SetItem(outputs, i, obj); } else { if (is_executable) { + // If one of the grad outputs is undefined, a correctly-shaped zeros + // should be used instead. To construct these for NJT, zeros_like() must + // be used until we have factory function support. // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - self->output_info.emplace_back(*wrapped_outputs[i]); + bool is_differentiable = + (non_differentiable.count( + wrapped_outputs[i]->unsafeGetTensorImpl()) == 0 && + isDifferentiableType(wrapped_outputs[i]->scalar_type())); + bool use_zeros_like = is_differentiable && num_outputs > 1 && + wrapped_outputs[i]->is_nested(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + self->output_info.emplace_back(*wrapped_outputs[i], use_zeros_like); } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i])); diff --git a/torch/csrc/autograd/variable_info.cpp b/torch/csrc/autograd/variable_info.cpp index bffd3250fb088..4b62e5fc67ce2 100644 --- a/torch/csrc/autograd/variable_info.cpp +++ b/torch/csrc/autograd/variable_info.cpp @@ -2,6 +2,7 @@ #include #else #include +#include #endif #include @@ -9,13 +10,14 @@ namespace torch::autograd { -VariableInfo::VariableInfo(const Variable& var) +VariableInfo::VariableInfo(const Variable& var, bool use_zeros_like) : layout(var.layout()), device(var.device()), scalar_type(var.scalar_type()), size(var.sym_sizes().vec()), requires_grad(var.requires_grad()), - is_empty(false) {} + is_empty(false), + the_var(use_zeros_like ? std::optional(var) : std::nullopt) {} VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {} @@ -23,6 +25,8 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { if (is_empty) { // Return undefined tensor. return at::Tensor(); + } else if (the_var.has_value()) { + return at::zeros_like(*the_var); } else { return at::zeros_symint( size, at::TensorOptions(scalar_type).device(device).layout(layout)); diff --git a/torch/csrc/autograd/variable_info.h b/torch/csrc/autograd/variable_info.h index 63e88deb0d547..e26804e7e55fc 100644 --- a/torch/csrc/autograd/variable_info.h +++ b/torch/csrc/autograd/variable_info.h @@ -6,7 +6,7 @@ namespace torch::autograd { struct TORCH_API VariableInfo { explicit VariableInfo(); - explicit VariableInfo(const Variable& var); + explicit VariableInfo(const Variable& var, bool use_zeros_like = false); Variable zeros(at::OptionalDeviceGuard& device_guard) const; @@ -16,6 +16,8 @@ struct TORCH_API VariableInfo { std::vector size; bool requires_grad; bool is_empty; + // needed for e.g. NJTs since they only support zeros_like() + std::optional the_var; }; } // namespace torch::autograd