
prims.div - Fails with torch_compile executor and incorrect torchex implementation #1695

Closed
kshitij12345 opened this issue Jan 24, 2025 · 4 comments · Fixed by #1697
Labels: bug (Something isn't working)

Comments

@kshitij12345 (Collaborator)

```python
import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(x, y):
    return thunder.prims.div(x, y) + x

jfn = thunder.jit(fn, executors=[torch_compile_ex,])
# jfn = thunder.jit(fn)  # Works
jfn(torch.ones(3, 3), torch.ones(3, 3))
print(thunder.last_traces(jfn))
```
Error
```
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable(floating)] {}

from user code:
   File "thunder.to_be_compiled_3", line 9, in to_be_compiled
    t2 = torch_div_prim_impl(x, y)  # t2: "cpu f32[3, 3]"
  File "/home/kkalambarkar/lightning-thunder/thunder/executors/torchex.py", line 931, in _div_prim_impl
    if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):
  File "/home/kkalambarkar/lightning-thunder/thunder/core/dtypes.py", line 468, in is_integer_dtype
    return dtype in integer_dtypes
```

Found while running:

```python
import torch
from thunder.dynamo import thunderfx
from thunder.dev_utils.utils import _benchmark_fusion_region_with_nvfuser_and_torch_compile
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "microsoft/Phi-3-mini-128k-instruct"
configuration = AutoConfig.from_pretrained(
    model_id,
    # Scaled down for testing
    vocab_size=16,
    pad_token_id=15,
    max_position_embeddings=32,
    num_hidden_layers=1,
)
configuration.hidden_size = configuration.num_attention_heads
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16)

model = thunderfx(model, nv_store_fusion_inputs=True)

input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
model(input_ids=input_ids, labels=input_ids)

trc = model.last_traces[0]

for bsym in trc.bound_symbols:
    if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
        benchmark_comparison_data = _benchmark_fusion_region_with_nvfuser_and_torch_compile(bsym)
```
@kshitij12345 (Collaborator, Author) commented Jan 24, 2025

Also, there seems to be a bug in _div_prim_impl: we check dtypes.is_exact_dtype(to_dtype(a.dtype)) twice, and I think the second check should be on b instead.

```python
def _div_prim_impl(a: Number | torch.Tensor, b: Number | torch.Tensor) -> torch.Tensor:
    if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):
```

Also, there is a bug with handling of Number and Tensor inputs in _div_prim_impl.

```python
import torch
import thunder
from thunder.executors.torchex import ex as torchex

def fn(x, y):
    return thunder.prims.div(x, y) + x

jfn = thunder.jit(fn, executors=[torchex,])
jfn(1., torch.ones(3, 3))
```

Error

```
File "/home/kkalambarkar/lightning-thunder/thunder/executors/torchex.py", line 932, in _div_prim_impl
    if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):
AttributeError: 'float' object has no attribute 'dtype'
```

@kshitij12345 kshitij12345 added the bug Something isn't working label Jan 24, 2025
@kshitij12345 kshitij12345 changed the title prims.div - Fails with torch_compile executor prims.div - Fails with torch_compile executor and incorrect torchex implementation Jan 24, 2025
@mruberry (Collaborator)

Good catch on the torch executor's error in _div_prim_impl! That should be a straightforward fix.

Separately, I'm not sure if the torch compile executor supports calling the primitive directly like this. We should detect and prevent the torch compile executor from executing the prims.div call.

@kshitij12345 (Collaborator, Author) commented Jan 27, 2025

> I'm not sure if the torch compile executor supports calling the primitive directly like this. We should detect and prevent the torch compile executor from executing the prims.div call.

It does work for other primitives like add. div fails only because _div_prim_impl converts to thunder dtypes and queries them (e.g. whether the dtype is an integer type), which torch.compile can't trace. Doing the same queries with torch.dtype works fine.
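To illustrate the distinction, here is a hypothetical stand-in for thunder's check (not thunder code): the same exactness query phrased against torch.dtype attributes, which Dynamo knows how to trace, so the compiled division-like function avoids the membership test on thunder's own dtype-object sets that triggered the Unsupported error above.

```python
import torch

def is_exact_torch_dtype(x):
    # Same "is this an exact (non-floating, non-complex) dtype?" query,
    # but expressed via torch.dtype attributes that Dynamo can trace.
    return not (x.dtype.is_floating_point or x.dtype.is_complex)

@torch.compile
def div_prim_like(x, y):
    # Branching on a dtype property is static per compilation, so
    # torch.compile specializes on it without graph breaks.
    if is_exact_torch_dtype(x) and is_exact_torch_dtype(y):
        return torch.div(x, y, rounding_mode="trunc")
    return torch.true_divide(x, y)
```

The branch here is resolved at trace time from tensor metadata, which is why the torch.dtype form compiles while the thunder-dtype form does not.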

Currently torch_compile supports every symbol supported by torchex. I think this is valid and shouldn't be a problem in general.

```python
torch_compile_ex = TorchCompileExecutor(name="torchcompile")
register_executor(torch_compile_ex)
torch_compile_ex._implmap = {op: ImplInfo() for op in pytorch_ex.implmap}
```

@mruberry (Collaborator)

> I'm not sure if the torch compile executor supports calling the primitive directly like this. We should detect and prevent the torch compile executor from executing the prims.div call.
>
> It does work on other primitives like add. div only fails because _div_prim_impl tries to convert to thunder dtype and do queries on dtype (like if it is int or not) which torch.compile doesn't like. Doing the same operations in torch with torch.dtype works fine.

Good point; PR approved.
