thunderfx : detecting parameters and buffers on thunderfx path #1575
kshitij12345 added the jit and thunderfx (for things that could be applicable to the dynamo+thunder frontend) labels on Dec 19, 2024
Example of ExtractionOnlyPrologueTransform not working:

```python
import torch
import thunder
import thunder.transforms
from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
from thunder.dynamo import thunderfx

model = torch.nn.Linear(16, 16)
x = torch.randn(16, 16)

cmodel = thunderfx(model, transforms=[ExtractionOnlyPrologueTransform()])
_ = cmodel(x)

assert len(cmodel._backend.subgraph_infos) == 1
subgraph_info = cmodel._backend.subgraph_infos[0]
thunder_fn = subgraph_info.thunder_compiled_fns[0]
original_subgraph = subgraph_info.original_graph_module

prl_trc = thunder.last_prologue_traces(thunder_fn)[-1]
with open("prologue_trc.py", "w") as f:
    f.write(str(prl_trc))
```

Prologue Trace:

```python
# Constructed by Transform for execution (took 3 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 3)
    # prims.check_len(args, 3)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  l_args_0_: "cpu f32[16, 16]" = args[0]
  l_fn_parameters_weight_: "cpu f32[16, 16]" = args[1]
  l_fn_parameters_bias_: "cpu f32[16]" = args[2]
  check_tensor_metadata(l_args_0_, (16, 16), 'cpu', torch.float32, False)
    # prims.check_tensor_shape_and_metadata(l_args_0_, (16, 16), 'cpu', torch.float32, False)
  check_tensor_metadata(l_fn_parameters_weight_, (16, 16), 'cpu', torch.float32, True)
    # prims.check_tensor_shape_and_metadata(l_fn_parameters_weight_, (16, 16), 'cpu', torch.float32, True)
  check_tensor_metadata(l_fn_parameters_bias_, (16,), 'cpu', torch.float32, True)
    # prims.check_tensor_shape_and_metadata(l_fn_parameters_bias_, (16,), 'cpu', torch.float32, True)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
    # prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
    # prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
    # prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
    # prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
  check_string_value(cache_info_alias_tensor_indices, '')
    # prims.check_string_value(cache_info_alias_tensor_indices, '')
  cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
  check_number_type_and_value(cache_info_is_grad_enabled, True)
    # prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
  return ((l_args_0_, l_fn_parameters_weight_, l_fn_parameters_bias_), ())
```
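Note that even with ExtractionOnlyPrologueTransform applied, the prologue above still contains check_tensor_metadata and other check calls for l_fn_parameters_weight_ and l_fn_parameters_bias_; presumably these checks would be stripped if the corresponding proxies had been recognized as parameters and tagged with STATIC_MEMORY_LOCATION, which is the gap described below.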
The FXGraph provided by Dynamo takes in Parameters and Buffers as arguments; however, thunder.jit currently only determines a TensorProxy to be a parameter if it is unpacked from a Module. So, on the thunderfx path, we don't tag these parameters with STATIC_MEMORY_LOCATION, leading to problems with CUDAGraphTransform and ExtractionOnlyPrologueTransform, which depend on these tags.

lightning-thunder/thunder/core/jit_ext.py
Lines 1493 to 1500 in 35ca2e9
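To illustrate the first point independently of thunder, here is a small sketch (not from the issue) that prints the placeholder inputs of the FX graph Dynamo hands to a backend; depending on the PyTorch version and Dynamo configuration, the module's weight and bias can appear there as explicit graph inputs:

```python
# Hedged sketch: shows that Dynamo can pass parameters/buffers as graph inputs.
# Placeholder names, and whether parameters are lifted to inputs at all, vary
# with the PyTorch version and Dynamo configuration.
import torch

def printing_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print every placeholder (graph input) of the captured FX graph.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            print(node.target)
    return gm.forward  # fall back to running the captured graph eagerly

model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend=printing_backend)
_ = compiled(torch.randn(16, 16))
```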
Potential Solution for Parameters: for parameters, maybe thunder.jit could tag Proxies based on isinstance(obj, nn.Parameter).

Sample: see "Example of ExtractionOnlyPrologueTransform not working" above.
Output: see the "Prologue Trace" above.
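As a rough, non-authoritative sketch of the isinstance-based idea (the helper name below is hypothetical and not part of thunder.jit):

```python
# Hypothetical sketch: decide parameter-ness from the object's type rather
# than from whether it was unpacked from a Module, so that parameters passed
# directly as FX graph inputs on the thunderfx path could still be tagged.
import torch
import torch.nn as nn

def looks_like_parameter(obj) -> bool:
    # nn.Parameter covers parameters; buffers are plain tensors, so they
    # would still need a separate mechanism (see the issue title).
    return isinstance(obj, nn.Parameter)

# Example: a Linear layer's weight passes the check even outside its Module.
weight = torch.nn.Linear(16, 16).weight
assert looks_like_parameter(weight)
assert not looks_like_parameter(torch.randn(16, 16))
```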