From 6c5b0e1d7a7e9aebb2bc5305cd53e6a77fdeb2dd Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Fri, 17 Jan 2025 12:33:07 +0200 Subject: [PATCH 01/18] add cross_entropy to torchcompile_cat executor --- thunder/executors/torch_compile.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 5aaba1e64f..802ce3b1a5 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -216,6 +216,11 @@ def cuda_device_checker(*args, **kwargs): required_ops = { "torch.cat", prims.cat.id, + "nll_loss_backward", + "log_softmax_backward", + "torch.log_softmax", + "torch.nn.functional.nll_loss", + "torch.nn.functional.cross_entropy", } torch_compile_cat_ex = TorchCompileExecutor(name="torchcompile_cat", required_ops=required_ops) register_executor(torch_compile_cat_ex) @@ -238,7 +243,19 @@ def cuda_device_checker(*args, **kwargs): # parallel residual paths are used in the transformer block prims.div.id, prims.erf.id, + # Ops needed to support fusing HF causal LM Loss. + prims.where.id, + prims.ne.id, + prims.take_along_axis.id, + "torch.take_along_dim", + "torch.Tensor.contiguous", + "torch.log_softmax", + "torch.nn.functional.nll_loss", + "torch.nn.functional.cross_entropy", + "nll_loss_backward", + "log_softmax_backward", } + torch_compile_cat_ex._implmap = { op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops } From ecc1f7c0982d093a3474c7b6c26ea3d7ee4de8c1 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:13:55 +0200 Subject: [PATCH 02/18] Revert "add cross_entropy to torchcompile_cat executor" This reverts commit 6c5b0e1d7a7e9aebb2bc5305cd53e6a77fdeb2dd. --- thunder/executors/torch_compile.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 802ce3b1a5..5aaba1e64f 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -216,11 +216,6 @@ def cuda_device_checker(*args, **kwargs): required_ops = { "torch.cat", prims.cat.id, - "nll_loss_backward", - "log_softmax_backward", - "torch.log_softmax", - "torch.nn.functional.nll_loss", - "torch.nn.functional.cross_entropy", } torch_compile_cat_ex = TorchCompileExecutor(name="torchcompile_cat", required_ops=required_ops) register_executor(torch_compile_cat_ex) @@ -243,19 +238,7 @@ def cuda_device_checker(*args, **kwargs): # parallel residual paths are used in the transformer block prims.div.id, prims.erf.id, - # Ops needed to support fusing HF causal LM Loss. - prims.where.id, - prims.ne.id, - prims.take_along_axis.id, - "torch.take_along_dim", - "torch.Tensor.contiguous", - "torch.log_softmax", - "torch.nn.functional.nll_loss", - "torch.nn.functional.cross_entropy", - "nll_loss_backward", - "log_softmax_backward", } - torch_compile_cat_ex._implmap = { op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops } From 0927d743ef5e11f7e2285528ada90e91c910a305 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:14:58 +0200 Subject: [PATCH 03/18] move xentropy in its own executor --- thunder/executors/torch_compile.py | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 5aaba1e64f..e393fa35da 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -243,6 +243,42 @@ def cuda_device_checker(*args, **kwargs): op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops } +# Similar to torchcomile_cat, this executor is meant to be used with nvfuser_executor to allow +# inductor to claim cross_entropy computation. +required_ops = { + prims.reshape.id, + "nll_loss_backward", + "log_softmax_backward", + "torch.log_softmax", + "torch.nn.functional.nll_loss", + "torch.nn.functional.cross_entropy", +} +torch_compile_xentropy = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops) +register_executor(torch_compile_xentropy) + +supported_ops = { + prims.broadcast_in_dim.id, + prims.convert_element_type.id, + prims.div.id, + prims.ne.id, + prims.neg.id, + prims.reshape.id, + prims.slice_prim.id, + prims.where.id, + "nll_loss_backward", + "log_softmax_backward", + "torch.log_softmax", + "torch.nn.functional.cross_entropy", + "torch.nn.functional.nll_loss", + "torch.sum", + "torch.take_along_dim", + "torch.Tensor.contiguous", +} + +torch_compile_xentropy._implmap = { + op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops +} + torch_compile_ex = TorchCompileExecutor(name="torchcompile") register_executor(torch_compile_ex) From e11ddab8eb582fc9a2a0ffcdb441f7c50a4f8aae Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:16:19 +0200 Subject: [PATCH 04/18] reshape not always required --- thunder/executors/torch_compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index e393fa35da..92922e50d7 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -246,7 +246,6 @@ def cuda_device_checker(*args, **kwargs): # Similar to torchcomile_cat, this executor is meant to be used with nvfuser_executor to allow # inductor to claim cross_entropy computation. required_ops = { - prims.reshape.id, "nll_loss_backward", "log_softmax_backward", "torch.log_softmax", From d815892476e7f434829b29e15a15b2cec581c0c7 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:28:04 +0200 Subject: [PATCH 05/18] add torchcompile_xentropy to the tests --- thunder/tests/framework.py | 19 ++++++++++++++++++- thunder/tests/test_extend.py | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index f02d3c3c3b..31eab06388 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -208,6 +208,22 @@ def version(self): return torch.__version__ +class TorchCompileXentropyTestExecutor(TestExecutor): + name = "torchcompile_xentropy" + supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA) + supported_dtypes = (datatypes.dtype,) + + def is_available(self) -> bool: + return not IS_WINDOWS + + def executors_list(self) -> list[extend.Executor]: + from thunder.executors.torch_compile import torch_compile_cat_ex + + return [torch_compile_cat_ex] + + def version(self): + return torch.__version__ + class TorchCompileCatTestExecutor(TestExecutor): name = "torchcompile_cat" supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA) @@ -261,6 +277,7 @@ def make_callable(self, fn, **kwargs): # TODO Refactor these executors into the actual executor (sub)modules TorchExecutor: TorchTestExecutor = TorchTestExecutor() TorchCompileCatExecutor: TorchCompileCatTestExecutor = TorchCompileCatTestExecutor() +TorchCompileXentropyExecutor: TorchCompileXentropyTestExecutor = TorchCompileXentropyTestExecutor() TorchCompileExecutor: TorchCompileTestExecutor = TorchCompileTestExecutor() DynamoThunderExecutor: DynamoThunderTestExecutor = DynamoThunderTestExecutor() nvFuserExecutor: None | nvFuserTestExecutor = None @@ -368,7 +385,7 @@ def __init__( self.supported_executors = ( set(supported_executors) if supported_executors is not None - else set(_all_test_executors() + [TorchCompileCatExecutor]) + else set(_all_test_executors() + [TorchCompileCatExecutor, TorchCompileXentropyExecutor]) ) for ex in self.supported_executors: assert isinstance(ex, TestExecutor) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index c5a45ffd83..7e4a4b04b2 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -132,6 +132,7 @@ def test_get_all_executors_includes_all_native_executors(): "sdpa", "torchcompile", "torchcompile_cat", + "torchcompile_xentropy", "python", "transformer_engine", } From 39530d313fe412ddd314d14967cbcbd2ae0b302d Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:28:43 +0200 Subject: [PATCH 06/18] add torchcompile_xentropy as default executor --- thunder/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 1fcc346471..2ba53cdfc9 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -191,16 +191,18 @@ cudnn_executor: None | extend.Executor = extend.get_executor("cudnn") sdpa_executor: None | extend.Executor = extend.get_executor("sdpa") torchcompile_cat_executor: None | extend.Executor = extend.get_executor("torchcompile_cat") +torchcompile_xentropy_executor: None | extend.Executor = extend.get_executor("torchcompile_xentropy") apex_executor: None | extend.Executor = extend.get_executor("apex") nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser") pytorch_executor: None | extend.Executor = extend.get_executor("torch") -# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python] +# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python] # Note that add_default_executor inserts executor at start of list, hence the reverse order below. if nvfuser_executor: add_default_executor(nvfuser_executor) if torchcompile_cat_executor and pytorch._dynamo.is_inductor_supported(): + add_default_executor(torchcompile_xentropy_executor) add_default_executor(torchcompile_cat_executor) if sdpa_executor: From 284e7d99e511d49bb6897e7166e522df5c76ba89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Jan 2025 15:29:30 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/framework.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index 31eab06388..8f458576f0 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -224,6 +224,7 @@ def executors_list(self) -> list[extend.Executor]: def version(self): return torch.__version__ + class TorchCompileCatTestExecutor(TestExecutor): name = "torchcompile_cat" supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA) From c4d98e9185c892f4f960155acc4d990f2ac64ff9 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:09:46 +0200 Subject: [PATCH 08/18] remove torch.nn.functional.cross_entropy from supported as it is in required --- thunder/executors/torch_compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 92922e50d7..6e74a4fa22 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -267,7 +267,6 @@ def cuda_device_checker(*args, **kwargs): "nll_loss_backward", "log_softmax_backward", "torch.log_softmax", - "torch.nn.functional.cross_entropy", "torch.nn.functional.nll_loss", "torch.sum", "torch.take_along_dim", From 9338118bb7e5ce5d1c012689fa8e5805953d5daa Mon Sep 17 00:00:00 2001 From: Ali Alshaarawy <45029495+ali-alshaar7@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:48:38 -0500 Subject: [PATCH 09/18] bump version (#1708) --- thunder/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/__about__.py b/thunder/__about__.py index 4e812d79f6..95164816a2 100644 --- a/thunder/__about__.py +++ b/thunder/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.0dev" +__version__ = "0.2.1dev" __author__ = "Lightning-AI et al" __author_email__ = "community@lightning.ai" __license__ = "Apache 2.0" From c875621588b9ba80a01e187e631dcaf77a48a8d3 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 28 Jan 2025 22:36:31 +0100 Subject: [PATCH 10/18] nvfuser: add option to allow shape only region (#1702) --- thunder/executors/nvfuserex_impl.py | 16 +++++++++++----- thunder/tests/test_nvfuser.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 11237e4c84..0399c1afcf 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -886,11 +886,17 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: else: bookend_result = {"front_bsyms": [], "fusion": region, "rear_bsyms": []} - # Don't fuse a region which has only Shape Operations. - all_shape_ops = all(map(lambda bsym: all_tagged(bsym, prims.OpTags.SHAPE_OP), bsyms)) - if all_shape_ops: - fused_bsyms.extend(bsyms) - continue + nv_enable_shape_only_fusion: None | bool = get_compile_option( + "nv_enable_shape_only_fusion", + "Allow nvFuser to create Fusion with shape only operations. Defaults to False.", + ) + + if not nv_enable_shape_only_fusion: + # Don't fuse a region which has only Shape Operations. + all_shape_ops = all(map(lambda bsym: all_tagged(bsym, prims.OpTags.SHAPE_OP), bsyms)) + if all_shape_ops: + fused_bsyms.extend(bsyms) + continue if len(bsyms) == 1: bsym: BoundSymbol = bsyms[0] diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 2926079cf5..a1a8897d95 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -1181,14 +1181,21 @@ def fn(a, b): dtypes=(thunder.float32,), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,), + decorators=(pytest.mark.parametrize("nv_enable_shape_only_fusion", [True, False, None]),), ) -def test_no_shape_only_fusion_region(executor, device: str, thunder_dtype: dtypes.dtype): +def test_no_shape_only_fusion_region( + executor, device: str, thunder_dtype: dtypes.dtype, nv_enable_shape_only_fusion: bool +): x = make_tensor(2, 2, 2, device=device, dtype=ltorch.to_torch_dtype(thunder_dtype)) def fn(x): return x.view(4, -1).transpose(0, 1) - jfn = thunder.jit(fn) + if nv_enable_shape_only_fusion is None: + options_dict = {} + else: + options_dict = {"nv_enable_shape_only_fusion": nv_enable_shape_only_fusion} + jfn = thunder.jit(fn, **options_dict) expected = fn(x) actual = jfn(x) @@ -1197,8 +1204,11 @@ def fn(x): fwd_trace = thunder.last_traces(jfn)[-1] - # Make sure there are no fusion symbols. - assert all(not bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols) + if nv_enable_shape_only_fusion: + assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols) + else: + # Make sure there are no fusion symbols. + assert all(not bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols) # Verify that we create fusion even if we have a single compute op. def fn(x): From 7e857f5bd62f3fabe8c2712b618cf3844177fcc8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 28 Jan 2025 13:37:47 -0800 Subject: [PATCH 11/18] Backward transform dependency fix (#1693) --- thunder/common.py | 6 ++- thunder/core/prims.py | 3 ++ thunder/core/proxies.py | 2 +- thunder/core/utils.py | 19 ++++++++-- thunder/core/vjp_utils.py | 4 +- thunder/tests/test_transforms.py | 63 ++++++++++++++++++++++++++++++++ 6 files changed, 91 insertions(+), 6 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index eb49d511c2..de5a6138bc 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -392,9 +392,13 @@ def translate(x: Any, *, name: str | None = None) -> Any: # TODO Update cacheable types def _make_subkey_for(x: Any) -> tuple[bool, None | tuple]: - if isinstance(x, (torch.Tensor, TensorProxy)): + if isinstance(x, torch.Tensor): return True, (type(x), x.shape, x.device, x.dtype, x.requires_grad) + if isinstance(x, TensorProxy): + # calling _shape instead shape to avoid leaving a prims.shape in the trace + return True, (type(x), x._shape, x.device, x.dtype, x.requires_grad) + # TODO Add NumPy ndarray support if isinstance(x, np.ndarray): return False, None diff --git a/thunder/core/prims.py b/thunder/core/prims.py index e8dc1861a0..f65abd2a0f 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1814,10 +1814,13 @@ def _put_grad_meta(grad_for: Number | NumberProxy | TensorProxy, grad: Number | return None +# PUT_GRAD is a sink node with side effects that updates Tensor.grad. It needs +# DONT_DCE tag to avoid removal in DCE pass. put_grad = make_prim( PrimIDs.PUT_GRAD, "put_grad", meta=_put_grad_meta, + tags=(OpTags.DONT_DCE,), ) # diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index e994c23e09..36a747a0c0 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1219,7 +1219,7 @@ def _infer_tensor_properties( if like is not None: baseutils.check_type(like, (TensorProxy, FutureTensorProxy)) - _shape = tuple(like.shape) + _shape = tuple(like._shape) _device = like.device _dtype = like.true_dtype _requires_grad = like.requires_grad diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 8eab3c8f3f..dc4e6f345d 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1097,7 +1097,8 @@ def find_producer_symbols(trace: TraceCtx, proxies: Sequence[Proxy], stop_proxie stop_proxies: proxies to stop at Returns: - tuple of symbols that produce the given proxies + tuple of symbols that produce the given proxies. In case of duplicate bound_symbols in trace, e.g. prims.shape. + The returned tuple would drop the duplicates and only preserve the first encounter. Example: >>> import torch @@ -1128,11 +1129,23 @@ def find_producer_symbols(trace: TraceCtx, proxies: Sequence[Proxy], stop_proxie if p is not None: result.add(p) for arg in p.flat_args: - arg_name = arg.name if isinstance(arg, Proxy) else None + if not isinstance(arg, Proxy): + continue + arg_name = arg.name if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen: queue.append(arg) seen.add(arg_name) - original_order = {bsym: i for i, bsym in enumerate(trace.bound_symbols)} + # original_order maps from bound_symbol to the index/order of its occurence in the trace. The order is + # used to sort producer bound symbols to preserve the correctness of data dependency. + original_order = dict() + for i, bsym in enumerate(trace.bound_symbols): + # Don't overwrite the order if it's already encountered. This is necessary for duplicate bsyms. + # e.g. duplicate shape queries. By preserving the smaller index, we ensure that the re-ordered + # shape queies would be placed before any consumers of its outputs, hence preserving the correctness + # of data dependency. + if bsym in original_order: + continue + original_order[bsym] = i return tuple(sorted(result, key=lambda x: original_order[x])) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index b6f3712e5e..686fedbab8 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -48,6 +48,7 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable import thunder from thunder.common import _make_cache_key from thunder.core.transforms import _get_gradfn_and_executor, eval_trace + from thunder.core.transform_common import dce joint_forward_backward, executor = _get_gradfn_and_executor(bsym) utils.check( @@ -59,7 +60,8 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): return cached_result - joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) + # dce is necessary to remove duplicated shape queries, otherwise the trace might overwritten NumberProxy variables + joint_trace = thunder.trace(inline_trace=False, use_dce=True)(joint_forward_backward, *bsym.args, **bsym.kwargs) consumers = utils.consumers(joint_trace) def find_backward_input(forward_output): diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index e1bbb24c40..666f8e05b9 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -792,3 +792,66 @@ def _count_shape_query(trace): # dce should remove duplicate shape queries trace = thunder.core.transforms.dce(trace) assert _count_shape_query(trace) == 1 + + +def test_cache_symbolic_values_grad_matmul(): + def foo(a, w): + return torch.nn.functional.linear(a, w) + + jfoo = thunder.jit(foo, cache="symbolic values") + set_requires_grad = lambda x: x.requires_grad_() + + a = torch.randn(2, 8, 6) + b = torch.randn(4, 6) + a_ref = a.clone() + b_ref = b.clone() + for x in (a, b, a_ref, b_ref): + x.requires_grad_() + actual = jfoo(a, b) + expected = foo(a_ref, b_ref) + actual.sum().backward() + expected.sum().backward() + + assert_close(actual, expected) + assert_close(a.grad, a_ref.grad) + assert_close(b.grad, b_ref.grad) + assert thunder.cache_misses(jfoo) == 1 + assert thunder.cache_hits(jfoo) == 0 + + a = torch.randn(4, 4, 2) + b = torch.randn(8, 2) + a_ref = a.clone() + b_ref = b.clone() + for x in (a, b, a_ref, b_ref): + x.requires_grad_() + actual = jfoo(a, b) + expected = foo(a_ref, b_ref) + actual.sum().backward() + expected.sum().backward() + + assert_close(actual, expected) + assert_close(a.grad, a_ref.grad) + assert_close(b.grad, b_ref.grad) + assert thunder.cache_misses(jfoo) == 1 + assert thunder.cache_hits(jfoo) == 1 + + +def test_cache_symbolic_values_grad_unsqueeze(): + def foo(x): + cache = torch.arange(0, 128, 1) + cache_unsqueezed = cache.unsqueeze(0) + return x + cache_unsqueezed + + jfoo = thunder.jit(foo, cache="symbolic values") + set_requires_grad = lambda x: x.requires_grad_() + + a = torch.randn(2, 8, 128) + a_ref = a.clone() + for x in (a, a_ref): + x.requires_grad_() + actual = jfoo(a) + expected = foo(a_ref) + actual.sum().backward() + expected.sum().backward() + assert_close(actual, expected) + assert_close(a.grad, a_ref.grad) From fcec023f3c0eecbd92c648784280ec3869e20fde Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 28 Jan 2025 22:49:22 +0100 Subject: [PATCH 12/18] pin check schema (#1709) --- .github/workflows/ci-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index e183fbb169..fcfff94c55 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -16,7 +16,7 @@ jobs: python-version: "3.10" check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@68c9200d341fe32ad7f5d7d1d46ae67f2f93836d with: azure-dir: ".azure" From 9a40bb00c856ee24a0f427873a5a8a2a8c2729ce Mon Sep 17 00:00:00 2001 From: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:54:56 +0200 Subject: [PATCH 13/18] bump transformers version (#1698) --- requirements/test.txt | 2 +- thunder/tests/test_networks.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index b1dd683eee..465ed20cea 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -18,7 +18,7 @@ pandas # thunder/benchmarks/test_benchmark_litgpt.py xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py jsonargparse # thunder/benchmarks/benchmark_litgpt.py bitsandbytes==0.42.0 # fixed version! -transformers==4.46.3 # for test_networks.py +transformers==4.48.1 # for test_networks.py # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 788ea29992..1bad432a2e 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -508,6 +508,11 @@ def test_hf_for_nemo(model_id): @requiresCUDA def test_hf_llama(): + import transformers + + if version_between(transformers.__version__, min_ver="4.46.4"): + pytest.skip("Dynamic cache is not supported, see static cache 'test_hf_kvcache'") + from transformers.models.llama import LlamaForCausalLM, LlamaConfig from transformers.models.llama.modeling_llama import logger as llama_logger from thunder.examine import get_fusion_symbols From 1df59292d1f038c66e4476507fae80d9c89ca6f1 Mon Sep 17 00:00:00 2001 From: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:01:23 +0200 Subject: [PATCH 14/18] Reduce `test_thunderfx_mistral_nemo_small` model size (#1701) Reducing the models size to speedup CI and make it work in environments with constrained memory size. I think this change should be fine since with this test we are verifying functionality more than memory footprint. Let me know what you think. With this change the peak memory of the test is ~2.6GB instead of ~14GB --- thunder/tests/test_networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 1bad432a2e..cff47fbdd5 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -392,14 +392,14 @@ def test_thunderfx_mistral_nemo_small(): config = transformers.models.mistral.configuration_mistral.MistralConfig( num_hidden_layers=1, torch_dtype=torch.bfloat16, - max_position_embeddings=1024, + max_position_embeddings=64, architectures=["MistralForCausalLM"], - hidden_size=5120, + hidden_size=1024, rms_norm_eps=1e-05, rope_theta=1000000.0, sliding_window=None, vocab_size=131072, - head_dim=128, + head_dim=32, _name_or_path=model_id, ) model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=False) From 71bd489ea592580c4fa1c34a9701c040711eae42 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:14:24 +0200 Subject: [PATCH 15/18] fix naming --- thunder/executors/torch_compile.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 6e74a4fa22..0b4bed09d1 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -252,8 +252,8 @@ def cuda_device_checker(*args, **kwargs): "torch.nn.functional.nll_loss", "torch.nn.functional.cross_entropy", } -torch_compile_xentropy = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops) -register_executor(torch_compile_xentropy) +torch_compile_xentropy_ex = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops) +register_executor(torch_compile_xentropy_ex) supported_ops = { prims.broadcast_in_dim.id, @@ -273,7 +273,7 @@ def cuda_device_checker(*args, **kwargs): "torch.Tensor.contiguous", } -torch_compile_xentropy._implmap = { +torch_compile_xentropy_ex._implmap = { op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops } From d94a4205bc4439d849c3a650505875ed6e8b3420 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:20:18 +0200 Subject: [PATCH 16/18] add pad prims --- thunder/executors/torch_compile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 0b4bed09d1..5f4532dbce 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -261,6 +261,7 @@ def cuda_device_checker(*args, **kwargs): prims.div.id, prims.ne.id, prims.neg.id, + prims.pad.id, prims.reshape.id, prims.slice_prim.id, prims.where.id, From 35c4606278090ad2637b68165ad435fd65abb487 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:21:46 +0200 Subject: [PATCH 17/18] add test for the new executor --- thunder/tests/test_torch_compile_executor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index c388dd6347..cb80b4bf6b 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -122,3 +122,20 @@ def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: out_jitted = forward_and_loss_jitted(model, input_ids) assert_close(out, out_jitted) + + +@requiresCUDA +def test_torch_compile_xentropy_loss(): + from transformers.loss.loss_utils import ForCausalLMLoss + + logits = torch.randn(1, 2, 6, device="cuda", requires_grad=True) + labels = torch.randint(0, 6, (1, 2), device="cuda") + vocab_size = 6 + + closs_fn = thunder.jit(ForCausalLMLoss, executors=[torch_compile_xentropy_ex]) + _ = closs_fn(logits, labels, vocab_size, ignore_index=-1) + forward_trace = thunder.last_traces(closs_fn)[-1].python() + + # make a single torch.compile region + assert "TorchCompile0" in forward_trace + assert "TorchCompile1" not in forward_trace From 3edc0a367f449f19efdf580faea9e0d98a8a29cf Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:57:28 +0200 Subject: [PATCH 18/18] add missing import --- thunder/tests/test_torch_compile_executor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index cb80b4bf6b..ce66d75d3d 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -4,7 +4,12 @@ from torch._dynamo import is_inductor_supported import thunder -from thunder.executors.torch_compile import supported_ops, torch_compile_ex, torch_compile_cat_ex +from thunder.executors.torch_compile import ( + supported_ops, + torch_compile_ex, + torch_compile_cat_ex, + torch_compile_xentropy_ex, +) from thunder.executors.torchex import ex as pytorch_ex from thunder.executors.nvfuserex import nvfuserex from thunder.tests.bf16 import device_supports_bf16