Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torchcompile_xentropy executor #1655

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,18 @@
cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
torchcompile_cat_executor: None | extend.Executor = extend.get_executor("torchcompile_cat")
torchcompile_xentropy_executor: None | extend.Executor = extend.get_executor("torchcompile_xentropy")
apex_executor: None | extend.Executor = extend.get_executor("apex")
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
pytorch_executor: None | extend.Executor = extend.get_executor("torch")

# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python]
# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python]
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
if nvfuser_executor:
    add_default_executor(nvfuser_executor)

if torchcompile_cat_executor and pytorch._dynamo.is_inductor_supported():
    # Each executor lookup above may yield None (see the `None | extend.Executor`
    # annotations), so guard the xentropy executor individually instead of
    # relying on the cat executor's availability implying the xentropy one's.
    if torchcompile_xentropy_executor:
        add_default_executor(torchcompile_xentropy_executor)
    add_default_executor(torchcompile_cat_executor)

if sdpa_executor:
Expand Down
35 changes: 35 additions & 0 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,41 @@ def cuda_device_checker(*args, **kwargs):
op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
}

# Similar to torchcompile_cat, this executor is meant to be used with nvfuser_executor to allow
# inductor to claim cross_entropy computation.
# Ops gating this executor's participation — presumably a region must contain
# these for the executor to claim it (see TorchCompileExecutor's required_ops
# handling) — TODO confirm the exact matching semantics.
required_ops = {
    "nll_loss_backward",
    "log_softmax_backward",
    "torch.log_softmax",
    "torch.nn.functional.nll_loss",
    "torch.nn.functional.cross_entropy",
}
torch_compile_xentropy = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops)
register_executor(torch_compile_xentropy)

# Full set of ops this executor may claim: the cross-entropy forward/backward
# ops plus the surrounding elementwise/shape prims needed to form one region.
supported_ops = {
    prims.broadcast_in_dim.id,
    prims.convert_element_type.id,
    prims.div.id,
    prims.ne.id,
    prims.neg.id,
    prims.reshape.id,
    prims.slice_prim.id,
    prims.where.id,
    "nll_loss_backward",
    "log_softmax_backward",
    "torch.log_softmax",
    "torch.nn.functional.cross_entropy",
    "torch.nn.functional.nll_loss",
    "torch.sum",
    "torch.take_along_dim",
    "torch.Tensor.contiguous",
}

# Reuse the PyTorch executor's implementations, restricted to the supported
# ops; cuda_device_checker presumably limits claiming to CUDA inputs — see its
# definition above.
torch_compile_xentropy._implmap = {
    op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
}


# General-purpose torch.compile executor with no required-op restriction.
torch_compile_ex = TorchCompileExecutor(name="torchcompile")
register_executor(torch_compile_ex)
Expand Down
20 changes: 19 additions & 1 deletion thunder/tests/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,23 @@ def version(self):
return torch.__version__


class TorchCompileXentropyTestExecutor(TestExecutor):
    """Test executor that runs tests against the ``torchcompile_xentropy`` executor."""

    name = "torchcompile_xentropy"
    supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
    supported_dtypes = (datatypes.dtype,)

    def is_available(self) -> bool:
        # torch.compile / inductor is not supported on Windows.
        return not IS_WINDOWS

    def executors_list(self) -> list[extend.Executor]:
        # Fix: this test executor must exercise the xentropy executor;
        # it previously returned torch_compile_cat_ex, silently testing
        # the torchcompile_cat executor instead.
        from thunder.executors.torch_compile import torch_compile_xentropy

        return [torch_compile_xentropy]

    def version(self):
        return torch.__version__


class TorchCompileCatTestExecutor(TestExecutor):
name = "torchcompile_cat"
supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
Expand Down Expand Up @@ -261,6 +278,7 @@ def make_callable(self, fn, **kwargs):
# TODO Refactor these executors into the actual executor (sub)modules
TorchExecutor: TorchTestExecutor = TorchTestExecutor()
TorchCompileCatExecutor: TorchCompileCatTestExecutor = TorchCompileCatTestExecutor()
TorchCompileXentropyExecutor: TorchCompileXentropyTestExecutor = TorchCompileXentropyTestExecutor()
TorchCompileExecutor: TorchCompileTestExecutor = TorchCompileTestExecutor()
DynamoThunderExecutor: DynamoThunderTestExecutor = DynamoThunderTestExecutor()
nvFuserExecutor: None | nvFuserTestExecutor = None
Expand Down Expand Up @@ -368,7 +386,7 @@ def __init__(
self.supported_executors = (
set(supported_executors)
if supported_executors is not None
else set(_all_test_executors() + [TorchCompileCatExecutor])
else set(_all_test_executors() + [TorchCompileCatExecutor, TorchCompileXentropyExecutor])
)
for ex in self.supported_executors:
assert isinstance(ex, TestExecutor)
Expand Down
1 change: 1 addition & 0 deletions thunder/tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_get_all_executors_includes_all_native_executors():
"sdpa",
"torchcompile",
"torchcompile_cat",
"torchcompile_xentropy",
"python",
"transformer_engine",
}
Expand Down
Loading