Skip to content

Commit

Permalink
fix: handle buffer setattr at root module (#1692)
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-alshaar7 authored Jan 27, 2025
1 parent 98f286c commit 837146e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
12 changes: 8 additions & 4 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,10 +1824,14 @@ def process_recorded_modifications(ctx, epilogue_trace):
and modified_object.provenance.inputs[1].value == "_buffers"
):
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) # todo: better criterion
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
if modified_object.provenance.inputs[0].inst is PseudoInst.INPUT_FN:
name = [""]
root_module_provenance = modified_object.provenance.inputs[0]
else:
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
# we want this to be created in the compute trace context for namespace...
Expand Down
17 changes: 1 addition & 16 deletions thunder/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,27 +710,12 @@ def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch
# update the cache
return self.k, self.v

# BUG: issue: 1637
class ParentModule(nn.Module):
def __init__(
self,
k_shape: tuple[int, int, int, int],
v_shape: tuple[int, int, int, int],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)

def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.cast_module(k, v)

with torch.device("cpu"):
k_shape = (2, 3, 4, 5)
v_shape = (2, 3, 4, 5)
device = torch.device("cpu")
dtype = torch.float32
model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
model = cast(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)

k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
Expand Down

0 comments on commit 837146e

Please sign in to comment.