
Cannot turn off rematerialization for an individual lang op with hand specified grad transform #1658

Open
kevinstephano opened this issue Jan 17, 2025 · 4 comments

Comments

kevinstephano commented Jan 17, 2025

I am trying to avoid recomputing torch.nn.functional.dropout because regenerating the random numbers can be expensive relative to the fusion it appears in. So I am trying to selectively not rematerialize the computation for dropout, and also to provide a custom grad function so that only a byte mask is saved from forward to backward. Without the custom grad function, the autograd transform saves the random float numbers instead of the byte mask.
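
For a rough sense of the memory difference on the 1024x1024 float32 tensors used in the repro below (plain PyTorch, illustrative only; not Thunder code):

import torch

# Saving the uniform random floats vs. saving only a byte (b8) dropout mask.
a = torch.empty(1024, 1024, dtype=torch.float32)
mask = torch.empty(1024, 1024, dtype=torch.bool)
print(a.numel() * a.element_size())        # 4194304 bytes (~4 MiB) if the random floats are saved
print(mask.numel() * mask.element_size())  # 1048576 bytes (~1 MiB) if only the byte mask is saved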

🐛 Bug

To Reproduce

Patch to apply (make replace_uniform a no-op):

diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 37ebd5a1..0a11772d 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -755,6 +755,7 @@ def rematerialize_forward_and_backward(fw_trace: TraceCtx, bw_trace: TraceCtx) -


 def replace_uniform(trace: TraceCtx) -> TraceCtx:
+    return trace
     """For better rematerialization, replace the uniform operator with the stateless uniform_philox operator and manually update the RNG state."""
     start_time_ns = time.perf_counter_ns()
     from thunder.core.trace import VariableInterface

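Instead of the unconditional early return above, the bypass could be made switchable for experimentation. A minimal sketch, assuming a hypothetical THUNDER_DISABLE_REPLACE_UNIFORM environment variable (not an existing Thunder option):

import os

def replace_uniform(trace: "TraceCtx") -> "TraceCtx":
    """For better rematerialization, replace the uniform operator with the stateless uniform_philox operator and manually update the RNG state."""
    # Hypothetical escape hatch, not an existing Thunder option: skip the
    # uniform -> uniform_philox rewrite so the random values are not
    # regenerated from the RNG state in the backward pass.
    if os.environ.get("THUNDER_DISABLE_REPLACE_UNIFORM", "0") == "1":
        return trace
    ...  # original body of replace_uniform unchanged
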
Patch to apply (register a custom grad for dropout):

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 9deed5fe..5676eead 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1460,6 +1460,49 @@ def _log_sigmoid_grad(

 register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)

+def _dropout_grad(a: TensorProxy, /, p: NumberLike = 0.5, training: bool = True, inplace: bool = False) -> TensorProxy:
+    #assert False, "I am here!"
+
+    if inplace:
+        raise NotImplementedError("Only inplace=False is currently supported in dropout")
+
+    if not training:
+        fwd = a
+        g = get_grad(fwd)
+        put_grad(a, g)
+        return fwd
+
+    utils.check(
+        p <= 1 and p >= 0,
+        lambda: f"Dropout probability has to be between 0 and 1, but got, {p}",
+    )
+
+    fwd: TensorProxy
+
+    if p == 1:
+        fwd = zeros_like(a)
+        put_grad(a, zeros_like(a))
+        return fwd
+    if p == 0:
+        fwd = a
+        g = get_grad(fwd)
+        put_grad(a, g)
+        return fwd
+
+    scale = 1 / (1 - p)
+    r = clang.uniform_like(a, 0.0, 1.0)
+    dropout_mask = r < (1 - p)
+
+    scaled_dropout_mask = scale * dropout_mask
+    fwd = a * scaled_dropout_mask
+
+    g = get_grad(fwd)
+    put_grad(a, g * scaled_dropout_mask)
+
+    return fwd
+
+register_grad("torch.nn.functional.dropout", _dropout_grad)
+

 #
 # Phantom grad transform helpers
Repro script:

import torch
import thunder

class TestMod(torch.nn.Module):
    def __init__(self):
        super(TestMod, self).__init__()
        # Most implementations also include some dropout
        self.dropout = torch.nn.Dropout(p=0.9)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x)

inputs = [torch.randn(1024, 1024, device='cuda', requires_grad=True),]
grads = torch.randn(1024, 1024, device='cuda', requires_grad=False),

model = TestMod()
#model = thunder.jit(model, enable_saved_for_backward_recomputation=False)
model = thunder.jit(model)

out = model(*inputs)
out.backward(grads)

fwd_trace1 = thunder.last_traces(model)[0]
fwd_trace2 = thunder.last_traces(model)[-1]
bwd_trace1 = thunder.last_backward_traces(model)[0]
bwd_trace2 = thunder.last_backward_traces(model)[-1]

print("FORWARD TRACE")
print(fwd_trace1)
print(fwd_trace2)
print("BACKWARD TRACE")
print(bwd_trace1)
print(bwd_trace2)

If I look at the result, I see that the rematerialization happens even in the presence of a register_grad function:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  t9, = C0
  clear_mutable_collection(C0)
  del C0
  [bw_t1] = nvFusion0(t9, t0)
    # bw_t10 = prims.lt(t9, 0.09999999999999998)  # bw_t10: "cuda:0 b8[1024, 1024]"
    # bw_t11 = prims.convert_element_type(bw_t10, dtypes.float32_)  # bw_t11: "cuda:0 f32[1024, 1024]"
    # bw_t12 = prims.mul(10.000000000000002, bw_t11)  # bw_t12: "cuda:0 f32[1024, 1024]"
    # bw_t1 = prims.mul(t0, bw_t12)  # bw_t1: "cuda:0 f32[1024, 1024]"
  del t9, t0
  return (bw_t1,)

If I add enable_saved_for_backward_recomputation=False to my thunder.jit call, then rematerialization does not take over:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  t10, = C0
  clear_mutable_collection(C0)
  del C0
  [bw_t1] = nvFusion0(t10, t0)
    # t11 = prims.convert_element_type(t10, dtypes.float32_)  # t11: "cuda:0 f32[1024, 1024]"
    # t12 = prims.mul(10.000000000000002, t11)  # t12: "cuda:0 f32[1024, 1024]"
    # bw_t1 = prims.mul(t0, t12)  # bw_t1: "cuda:0 f32[1024, 1024]"
  del t10, t0
  return (bw_t1,)

I would like to be able to selectively tag dropout to say "just don't rematerialize this op." However, adding the tag does not work:

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0d10f035..3eeb09df 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -4659,7 +4659,7 @@ def _cross_entropy_loss_label_smoothing(
 # TODO Move this to nn.functional
 # NOTE The id must be explicitly specified so as not to resolve to torch.dropout
 #   (Using torch.nn.functional.dropout is just for readability as it's the documented operator)
-@torchsymbol(torch.nn.functional.dropout, id="torch.nn.functional.dropout")
+@torchsymbol(torch.nn.functional.dropout, id="torch.nn.functional.dropout", tags=(prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,))
 def dropout(a: TensorProxy, /, p: NumberLike = 0.5, training: bool = True, inplace: bool = False) -> TensorProxy:
     if inplace:
         raise NotImplementedError("Only inplace=False is currently supported in dropout")

cc @t-vi @riccardofelluga

t-vi commented Jan 18, 2025

I'm a bit confused here: the randomness is not actually recomputed?

That said, let's disable the auto-recomputation of intermediates for now if it causes more problems than it solves.

IvanYashchuk commented:

Changing grad definitions has little effect on the generated computation function because what to compute and how to compute it are separate in Thunder. Grad definitions specify only the "what" part; many trace execution transformations define the "how" part, and the recomputation/rematerialization passes are part of the "how".

Setting enable_saved_for_backward_recomputation=False forces Thunder to behave as if e536ddc had not been merged. It's the correct first step for what you're trying to achieve.

The second part of not recomputing dropout is disabling the "replace_uniform" transformation (or making it optional), as in the patch you provided above.

With these two things, the dropout mask in b8 dtype is saved for backward. Is this what you need?
#1660 allows disabling replace_uniform and #1659 reverts the recent change of using recompute_saved_for_backward (different from fusion rematerialization!) by default.

Returning to the request in the issue title: doing exactly what you're asking for is more complicated and involved. Dropout is a composite operation in Thunder, and by the time the fusion rematerialization pass gets the traces, dropout is gone; only its prims decomposition remains, with no reference to the non-decomposed op.
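
For illustration, a plain-PyTorch sketch of roughly what the prims decomposition computes (compare the lt / convert_element_type / mul ops in the backward trace above; torch.rand_like stands in for the uniform/uniform_philox prim, names are illustrative):

import torch

p = 0.9
a = torch.randn(1024, 1024)
r = torch.rand_like(a)                        # uniform / uniform_philox
mask = r < (1 - p)                            # prims.lt -> b8 mask
out = a * (mask.to(a.dtype) * (1 / (1 - p)))  # convert_element_type + mul + mul
# After decomposition there is no single "dropout" symbol left in the trace, so an
# op-level tag on torch.nn.functional.dropout has nothing to attach to by the time
# the fusion rematerialization pass runs.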

IvanYashchuk commented:

With #1661 merged, the enable_saved_for_backward_recomputation flag has no effect on the resulting fusions.
Only the disable_replace_uniform=True/False option added in #1660 determines whether the random mask is recomputed or saved.
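
A usage sketch, assuming the option from #1660 is exposed as a thunder.jit keyword argument (the exact surface should be checked against that PR):

model = thunder.jit(model, disable_replace_uniform=True)  # keep uniform; the b8 mask is saved for backward
# Default (False): uniform is replaced with uniform_philox and the mask is recomputed in backward.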

mruberry commented:

I think there are a few things we should follow up on from @kevinstephano's example above:

  1. We should consider a mechanism to let custom autograd formulas be more definitive and prevent the application of later transforms locally
  2. We should review the pattern for operators that return an input in the fwd
  3. We may want to add an explicit "save_for_backward" or "save_tensor_for_backward" operation that acts as an assertion when a tensor that was not explicitly saved for backward by the code ends up being used in backward (a hypothetical sketch follows below)

We should probably discuss this more over a VC.
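
Item 3 might look roughly like the following; this is purely a hypothetical sketch of the proposed assertion, not an existing Thunder API (names and signatures are invented):

explicitly_saved: set[str] = set()

def save_tensor_for_backward(t):
    # A grad rule would call this for every tensor it intends to reuse in backward.
    explicitly_saved.add(t.name)
    return t

def check_backward_inputs(backward_inputs):
    # A later pass could assert that every tensor reaching backward was explicitly saved.
    for t in backward_inputs:
        assert t.name in explicitly_saved, f"{t.name} is used in backward but was never explicitly saved"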

kevinstephano changed the title from "Cannot turn of rematerialization for an individual lang op with hand specified grad transform" to "Cannot turn off rematerialization for an individual lang op with hand specified grad transform" on Jan 20, 2025