Fix autograd.Function + NJT when an output grad is None (pytorch#136875)
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For an NJT forward output, the stored size info contains a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```

This PR fixes the issue by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would save just the NJT shape and construct zeros from it, but that requires factory function support for nested ints (WIP). So this is a short-term fix until that support lands.
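As a minimal sketch of the failure mode and workaround (assuming an NJT built directly via `torch.nested.nested_tensor` with `layout=torch.jagged`; this example is illustrative only and not part of the PR):

```python
import torch

# Build a small NJT; its shape contains a nested int for the ragged dimension.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 10), torch.randn(3, 10)], layout=torch.jagged
)

# A plain zeros() call with the NJT's symbolic sizes cannot allocate concrete
# storage, which is the error quoted above:
#   torch.zeros(nt.shape)  # RuntimeError: SymIntArrayRef expected to contain
#                          # only concrete integers

# zeros_like() works because it can consult the full NJT (values + offsets)
# rather than just a symbolic size list, which is why the nested case stores
# the tensor itself in VariableInfo.
grad_zeros = torch.zeros_like(nt)
```

This is the same shape information the autograd engine needs when it replaces an undefined output grad with zeros, which is why the nested case routes through `zeros_like()` in the changes below.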
Pull Request resolved: pytorch#136875
Approved by: https://github.com/soulitzer
jbschlosser authored and pytorchmergebot committed Oct 8, 2024
1 parent 5e3e1c0 commit 76ab1ab
Showing 4 changed files with 50 additions and 4 deletions.
30 changes: 30 additions & 0 deletions test/test_nestedtensor.py
@@ -7055,6 +7055,36 @@ def test_noncontiguous_to(self, device, dtype, contiguity):
         if nt._lengths is not None:
             self.assertEqual(nt3._lengths.device, other_device)
 
+    @dtypes(torch.float32)
+    def test_autograd_function_with_None_grad(self, device, dtype):
+        class MyFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, inp):
+                ctx.save_for_backward(inp)
+                out1 = inp + 1
+                out2 = inp * 2
+                return out1, out2
+
+            @staticmethod
+            def backward(ctx, grad_out1, grad_out2):
+                (inp,) = ctx.saved_tensors
+                return grad_out1 + grad_out2
+
+        f = MyFunction.apply
+        nt = random_nt_from_dims(
+            [5, None, 10],
+            device=device,
+            dtype=dtype,
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        # Only use one of the autograd.Function outputs downstream so that the grad
+        # for the other output is None. We're testing that the engine can allocate
+        # correctly-shaped (NJT) zeros for the grad of the other output in this case.
+        (out1, _) = f(nt)
+        out1.backward(torch.ones_like(out1))
+
     @dtypes(torch.float64, torch.float32, torch.half)
     def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
         values = torch.randn(10, 5, device=device, dtype=dtype)
12 changes: 11 additions & 1 deletion torch/csrc/autograd/python_function.cpp
@@ -733,8 +733,18 @@ static void _wrap_outputs(
         PyTuple_SetItem(outputs, i, obj);
       } else {
         if (is_executable) {
+          // If one of the grad outputs is undefined, a correctly-shaped zeros
+          // should be used instead. To construct these for NJT, zeros_like() must
+          // be used until we have factory function support.
           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-          self->output_info.emplace_back(*wrapped_outputs[i]);
+          bool is_differentiable =
+              (non_differentiable.count(
+                   wrapped_outputs[i]->unsafeGetTensorImpl()) == 0 &&
+               isDifferentiableType(wrapped_outputs[i]->scalar_type()));
+          bool use_zeros_like = is_differentiable && num_outputs > 1 &&
+              wrapped_outputs[i]->is_nested();
+          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+          self->output_info.emplace_back(*wrapped_outputs[i], use_zeros_like);
         }
         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
         PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
8 changes: 6 additions & 2 deletions torch/csrc/autograd/variable_info.cpp
@@ -2,27 +2,31 @@
 #include <ATen/Functions.h>
 #else
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
 #endif
 
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/autograd/variable_info.h>
 
 namespace torch::autograd {
 
-VariableInfo::VariableInfo(const Variable& var)
+VariableInfo::VariableInfo(const Variable& var, bool use_zeros_like)
     : layout(var.layout()),
       device(var.device()),
       scalar_type(var.scalar_type()),
       size(var.sym_sizes().vec()),
       requires_grad(var.requires_grad()),
-      is_empty(false) {}
+      is_empty(false),
+      the_var(use_zeros_like ? std::optional<Variable>(var) : std::nullopt) {}
 
 VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}
 
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   if (is_empty) {
     // Return undefined tensor.
     return at::Tensor();
+  } else if (the_var.has_value()) {
+    return at::zeros_like(*the_var);
   } else {
     return at::zeros_symint(
         size, at::TensorOptions(scalar_type).device(device).layout(layout));
4 changes: 3 additions & 1 deletion torch/csrc/autograd/variable_info.h
@@ -6,7 +6,7 @@ namespace torch::autograd {
 
 struct TORCH_API VariableInfo {
   explicit VariableInfo();
-  explicit VariableInfo(const Variable& var);
+  explicit VariableInfo(const Variable& var, bool use_zeros_like = false);
 
   Variable zeros(at::OptionalDeviceGuard& device_guard) const;
 
@@ -16,6 +16,8 @@ struct TORCH_API VariableInfo {
   std::vector<c10::SymInt> size;
   bool requires_grad;
   bool is_empty;
+  // needed for e.g. NJTs since they only support zeros_like()
+  std::optional<Variable> the_var;
 };
 
 } // namespace torch::autograd
