diff --git a/docs/source/torch.compiler_aot_inductor.rst b/docs/source/torch.compiler_aot_inductor.rst
index 257f16f40cc05b..e00101d241b4c4 100644
--- a/docs/source/torch.compiler_aot_inductor.rst
+++ b/docs/source/torch.compiler_aot_inductor.rst
@@ -185,3 +185,15 @@ display results akin to the following:
   0.4883
   0.4703
 [ CUDAFloatType{2,1} ]
+
+
+Troubleshooting
+---------------------------
+Below are some useful tools for debugging AOT Inductor.
+
+.. toctree::
+   :caption: Debugging Tools
+   :maxdepth: 1
+
+   logging
+   torch.compiler_aot_inductor_minifier
diff --git a/docs/source/torch.compiler_aot_inductor_minifier.rst b/docs/source/torch.compiler_aot_inductor_minifier.rst
new file mode 100644
index 00000000000000..6cfb420961a860
--- /dev/null
+++ b/docs/source/torch.compiler_aot_inductor_minifier.rst
@@ -0,0 +1,213 @@
+AOTInductor Minifier
+===========================
+
+If you encounter an error while using AOT Inductor APIs such as
+``torch._inductor.aoti_compile_and_package``, ``torch._inductor.aoti_load_package``,
+or while running the loaded model of ``aoti_load_package`` on some inputs, you can use the AOTInductor Minifier
+to create a minimal nn.Module that reproduces the error by setting ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``.
+
+
+At a high level, there are two steps to using the minifier:
+
+- Set ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True`` or set the environment variable ``DUMP_AOTI_MINIFIER=1``. Running the script that errors then produces a ``minifier_launcher.py`` script. The output directory is configurable by setting ``torch._dynamo.config.base_dir`` to a valid directory name.
+
+- Run the ``minifier_launcher.py`` script. If the minifier runs successfully, it generates runnable Python code in ``repro.py`` which reproduces the exact error.
+
+Here is sample code that generates an error because we inject an error on relu with
+``torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"``.
+
+
+.. code-block:: py
+
+    import torch
+    from torch._inductor import config as inductor_config
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = torch.nn.Linear(10, 16)
+            self.relu = torch.nn.ReLU()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, x):
+            x = self.fc1(x)
+            x = self.relu(x)
+            x = self.sigmoid(x)
+            return x
+
+
+    inductor_config.aot_inductor.dump_aoti_minifier = True
+    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"
+
+    with torch.no_grad():
+        model = Model().to("cuda")
+        example_inputs = (torch.randn(8, 10).to("cuda"),)
+        ep = torch.export.export(model, example_inputs)
+        package_path = torch._inductor.aoti_compile_and_package(ep, example_inputs)
+        compiled_model = torch._inductor.aoti_load_package(package_path)
+        result = compiled_model(*example_inputs)
+
+
+The code above generates the following error:
+
+::
+
+    RuntimeError: Failed to import /tmp/torchinductor_shangdiy/fr/cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py
+    SyntaxError: invalid syntax (cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py, line 29)
+
+This is because we injected an error on relu, and so the generated triton kernel looks like the one below. Note that we have ``compile error!``
+instead of ``relu``, which is why we get a ``SyntaxError``.
+
+.. code-block::
+
+    @triton.jit
+    def triton_poi_fused_addmm_relu_sigmoid_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+        xnumel = 128
+        xoffset = tl.program_id(0) * XBLOCK
+        xindex = xoffset + tl.arange(0, XBLOCK)[:]
+        xmask = xindex < xnumel
+        x2 = xindex
+        x0 = xindex % 16
+        tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
+        tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
+        tmp2 = tmp0 + tmp1
+        tmp3 = compile error!
+        tmp4 = tl.sigmoid(tmp3)
+        tl.store(in_out_ptr0 + (x2), tmp4, xmask)
+
+
+Since we have ``torch._inductor.config.aot_inductor.dump_aoti_minifier=True``, we also see an additional line indicating where ``minifier_launcher.py`` has
+been written to. The output directory is configurable by setting
+``torch._dynamo.config.base_dir`` to a valid directory name.
+
+::
+
+    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] Writing minified repro to:
+    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_21_08_602433-pid_2861654/minifier/minifier_launcher.py
+
+
+The ``minifier_launcher.py`` file has the following code. The ``exported_program`` contains the inputs to ``torch._inductor.aoti_compile_and_package``.
+The ``command='minify'`` parameter means the script will run the minifier to create a minimal graph module that reproduces the error. Alternatively, you can
+use ``command='run'`` to just compile, load, and run the loaded model (without running the minifier).
+
+.. code-block:: py
+
+    import torch
+    import torch._inductor.inductor_prims
+
+    import torch._dynamo.config
+    import torch._inductor.config
+    import torch._functorch.config
+    import torch.fx.experimental._config
+
+    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'
+    torch._inductor.config.aot_inductor.dump_aoti_minifier = True
+
+
+
+
+    isolate_fails_code_str = None
+
+
+
+    # torch version: 2.6.0a0+gitcd9c6e9
+    # torch cuda version: 12.0
+    # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84
+
+
+    # CUDA Info:
+    # nvcc: NVIDIA (R) Cuda compiler driver
+    # Copyright (c) 2005-2023 NVIDIA Corporation
+    # Built on Fri_Jan__6_16:45:21_PST_2023
+    # Cuda compilation tools, release 12.0, V12.0.140
+    # Build cuda_12.0.r12.0/compiler.32267302_0
+
+    # GPU Hardware Info:
+    # NVIDIA PG509-210 : 8
+
+    exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints/exported_program.pt2')
+    # print(exported_program.graph)
+    config_patches={}
+    if __name__ == '__main__':
+        from torch._dynamo.repro.aoti import run_repro
+        with torch.no_grad():
+            run_repro(exported_program, config_patches=config_patches, accuracy=False, command='minify', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints', check_str=None)
+
+
+If we keep the ``command='minify'`` option and run the script, we get the following output:
+
+::
+
+    ...
+    W1031 16:48:08.938000 3598491 torch/_dynamo/repro/aoti.py:89] Writing checkpoint with 3 nodes to /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/3.py
+    W1031 16:48:08.975000 3598491 torch/_dynamo/repro/aoti.py:101] Copying repro file for convenience to /data/users/shangdiy/pytorch/repro.py
+    Wrote minimal repro out to repro.py
+
+
+The ``repro.py`` looks like this. The exported program now contains only the relu node.
The minifier successfully reduced the graph to the op that raises the +error. + +.. code-block:: py + + import torch + from torch import tensor, device + import torch.fx as fx + from torch._dynamo.testing import rand_strided + from math import inf + import torch._inductor.inductor_prims + + import torch._dynamo.config + import torch._inductor.config + import torch._functorch.config + import torch.fx.experimental._config + + torch._inductor.config.generate_intermediate_hooks = True + torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error' + torch._inductor.config.aot_inductor.dump_aoti_minifier = True + + + + + isolate_fails_code_str = None + + + + # torch version: 2.6.0a0+gitcd9c6e9 + # torch cuda version: 12.0 + # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84 + + + # CUDA Info: + # nvcc: NVIDIA (R) Cuda compiler driver + # Copyright (c) 2005-2023 NVIDIA Corporation + # Built on Fri_Jan__6_16:45:21_PST_2023 + # Cuda compilation tools, release 12.0, V12.0.140 + # Build cuda_12.0.r12.0/compiler.32267302_0 + + # GPU Hardware Info: + # NVIDIA PG509-210 : 8 + + + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + + + def forward(self, linear): + relu = torch.ops.aten.relu.default(linear); linear = None + return (relu,) + + def load_args(reader): + buf0 = reader.storage('a4e748c3a3d0d4a78cde43e33ad0f9dd41d96e90', 512, device=device(type='cuda', index=0)) + reader.tensor(buf0, (8, 16), is_leaf=True) # linear + load_args._version = 0 + mod = Repro() + if __name__ == '__main__': + from torch._dynamo.repro.aoti import run_repro, repro_load_args + config_patches={} + with torch.no_grad(): + args = repro_load_args(load_args, save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints') + exported_program = torch.export.export(mod, args) + run_repro(exported_program, config_patches=config_patches, accuracy=False, command='run', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints', check_str=None) diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index 45d4a79decff14..07262118cf63d1 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -170,6 +170,78 @@ def inner(x): minifier_args=["--offload-to-disk"], ) + # Test that compile errors in AOTInductor can be repro'd (both CPU and CUDA) + def _test_aoti(self, device, expected_error): + # NB: The program is intentionally quite simple, just enough to + # trigger one minification step, no more (dedicated minifier tests + # should exercise minifier only) + run_code = f"""\ +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + return x +with torch.no_grad(): + model = Model().to("{device}") + example_inputs = (torch.randn(8, 10).to("{device}"),) + ep = torch.export.export( + model, example_inputs + ) + torch._inductor.aoti_compile_and_package( + ep, example_inputs + ) +""" + return self._run_full_test(run_code, None, expected_error, isolate=True) + + @unittest.skipIf(IS_JETSON, "Fails on Jetson") + @inductor_config.patch( + { + "cpp.inject_relu_bug_TESTING_ONLY": "compile_error", + "aot_inductor.dump_aoti_minifier": True, + } + ) + def test_aoti_cpu_compile_error(self): + res = 
self._test_aoti("cpu", "CppCompileError") + self.assertExpectedInline( + res.repro_module(), + """\ +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, linear): + relu = torch.ops.aten.relu.default(linear); linear = None + return (relu,)""", + ) + + @requires_gpu + @inductor_config.patch( + { + "triton.inject_relu_bug_TESTING_ONLY": "compile_error", + "aot_inductor.dump_aoti_minifier": True, + } + ) + def test_aoti_gpu_compile_error(self): + res = self._test_aoti(GPU_TYPE, "SyntaxError") + self.assertExpectedInline( + res.repro_module(), + """\ +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, linear): + relu = torch.ops.aten.relu.default(linear); linear = None + return (relu,)""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/repro/aoti.py b/torch/_dynamo/repro/aoti.py new file mode 100644 index 00000000000000..605ad153b351d5 --- /dev/null +++ b/torch/_dynamo/repro/aoti.py @@ -0,0 +1,455 @@ +# mypy: allow-untyped-defs +import argparse +import functools +import io +import logging +import os +import shutil +import sys +import textwrap +from importlib import import_module +from typing import Any, Dict, Optional, Union + +import torch +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + BuckTargetWriter, + extra_imports, + generate_config_string, + helper_for_dump_minify, + InputReader, + minifier_dir, + NopInputReader, +) +from torch.export import ExportedProgram +from torch.hub import tqdm + +from .after_aot import generate_compiler_repro_string + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + + +def dump_to_minify( + exported_program: ExportedProgram, + compiler_name: str, + options: Optional[Dict[str, Any]] = None, +): + out = io.StringIO() + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + save_graph_repro_ep( + out, + exported_program, + compiler_name, + save_dir=subdir, + command="minify", + options=options, + ) + return helper_for_dump_minify(out.getvalue()) + + +def save_graph_repro_ep( + fd, + exported_program: ExportedProgram, + compiler_name, + *, + options: Optional[Dict[str, str]] = None, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + check_str=None, +): + # save a graph repro using exported_program + fd.write( + generate_compiler_repro_exported_program( + exported_program, + options=options, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.aoti import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, check_str={check_str!r})\n" + ) + + +def save_graph_repro_string( + fd, + gm, + args, + compiler_name, + *, + config_patches=None, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + tracing_mode=None, + check_str=None, +): + # save a graph repro by dumping the `gm` as a string + if any( + isinstance(arg, torch.fx.experimental._backward_state.BackwardState) + for arg in args + ): + fd.write( + "Repro is not generated due to existence of BackwardState in graph input" + ) + return + 
fd.write( + generate_compiler_repro_string( + gm, + args, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.aoti import run_repro, repro_load_args\n") + fd.write( + f" config_patches={config_patches}\n" + f" with torch.no_grad():\n" + f" args = repro_load_args(load_args, save_dir={save_dir!r})\n" + f" exported_program = torch.export.export(mod, args)\n" + f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, check_str={check_str!r})\n" + ) + + +def dump_compiler_graph_state( + gm, args, compiler_name, *, config_patches=None, accuracy=None +): + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + # exported_program = torch.export.export(gm, tuple(args)) + with open(file_name, "w") as fd: + save_graph_repro_string( + fd, + gm, + args, + compiler_name, + config_patches=config_patches, + save_dir=subdir, + accuracy=accuracy, + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_compiler_repro_exported_program( + exported_program, + *, + options: Optional[Dict[str, str]] = None, + stable_output=False, + save_dir=None, +): + model_str = textwrap.dedent( + f""" +import torch +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + ep_path = os.path.join(save_dir, "exported_program.pt2") + torch.export.save(exported_program, ep_path) + + model_str += f"exported_program = torch.export.load('{ep_path}')\n" + model_str += "# print(exported_program.graph)\n" + model_str += f"config_patches={options}\n" + return model_str + + +def repro_load_args(load_args, save_dir): + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. 
We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return tuple(args) + + +def repro_common(options, exported_program): + torch._inductor.config.generate_intermediate_hooks = True + mod = exported_program.module() + args, kwargs = exported_program.example_inputs + return mod, args, kwargs + + +def repro_get_args(options, exported_program, config_patches): + mod, args, kwargs = repro_common(options, exported_program) + return mod, args, kwargs + + +def repro_run(options, exported_program, config_patches): + from torch._inductor import _aoti_compile_and_package_inner, aoti_load_package + + mod, args, kwargs = repro_common(options, exported_program) + + from torch.cuda import synchronize + + package_path = _aoti_compile_and_package_inner( + mod, + args, + kwargs, + load_and_run=False, + inductor_configs=config_patches, + ) + compiled = aoti_load_package(package_path) + assert not isinstance(compiled, str) + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + compiled(*args) + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +def repro_minify(options, exported_program, config_patches): + from functorch.compile import minifier + from torch._inductor import _aoti_compile_and_package_inner + + mod, args, kwargs = repro_common(options, exported_program) + compiler_name = "aot_inductor" + + from torch.cuda import synchronize + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + def module_fails(gm, flat_example_inputs, check_str=None): + # we have to export first so the in_spec and out_spec are populated + tuple_inputs = tuple(flat_example_inputs) + ep = torch.export.export(gm, tuple_inputs) + gm = ep.module() + try: + _aoti_compile_and_package_inner( + gm, + tuple_inputs, + kwargs, + load_and_run=True, + inductor_configs=config_patches, + ) + if need_sync: + synchronize() # ensure segfaults are surfaced + return False + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + return True + + minifier( + mod, + args, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, + compiler_name=compiler_name, + config_patches=config_patches, + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def run_repro( + exported_program, + # load_args, + # kwargs: Dict[str, Any], + *, + config_patches: Optional[Dict[str, str]] = None, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + check_str=None, + **more_kwargs, +): + for k in more_kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + raise NotImplementedError("check for accuracy is not supported yet") + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An AOTI repro script, typically 
triggering a bug in PyTorch AOTInductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify,analyze}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command]( + options, exported_program, config_patches=config_patches + ) diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index b8e31ee0f578dd..a3eaeb685400c3 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -32,6 +32,18 @@ def _get_module(self, t): r = re.sub(r"\n{3,}", "\n\n", r) return r.strip() + def get_exported_program_path(self): + # Extract the exported program file path from AOTI minifier's repro.py + # Regular expression pattern to match the file path + pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)' + # Search for the pattern in the text + match = re.search(pattern, self.repro_code) + # Extract and print the file path if a match is found + if match: + file_path = match.group(1) + return file_path + return None + def minifier_module(self): return self._get_module(self.minifier_code) @@ -197,12 +209,19 @@ def _run_repro(self, repro_dir, *, isolate=True): # 
`patch_code` is the code to be patched in every generated file; usually # just use this to turn on bugs via the config def _gen_test_code(self, run_code, repro_after, repro_level): + repro_after_line = ( + f"""\ +torch._dynamo.config.repro_after = "{repro_after}" +""" + if repro_after + else "" + ) return f"""\ import torch import torch._dynamo {_as_posix_path(torch._dynamo.config.codegen_config())} {_as_posix_path(torch._inductor.config.codegen_config())} -torch._dynamo.config.repro_after = "{repro_after}" +{repro_after_line} torch._dynamo.config.repro_level = {repro_level} torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" {run_code} diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 397739147c13bf..4c791494e8bc3d 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +import torch._inductor.config import torch.fx import torch.utils._pytree as pytree @@ -10,7 +11,13 @@ from torch._inductor.utils import InputType -__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] +__all__ = [ + "compile", + "list_mode_options", + "list_options", + "cudagraph_mark_step_begin", + "_aoti_compile_and_package_inner", +] def compile( @@ -73,7 +80,6 @@ def aoti_compile_and_package( Returns: Path to the generated artifact """ - from torch._inductor.package import package_aoti from torch.export import ExportedProgram if not isinstance(exported_program, ExportedProgram): @@ -90,10 +96,36 @@ def aoti_compile_and_package( "Please pass in a package path to aot_inductor_compile() instead " "of setting the aot_inductor.output_path config." ) - inductor_configs["aot_inductor.package"] = True - m = exported_program.module() - assert isinstance(m, torch.fx.GraphModule) + # a wrapper around aoti_compile_and_package_inner. + return aoti_compile_and_package_debug_wrapper( + exported_program, + args, + kwargs, + package_path=package_path, + inductor_configs=inductor_configs, + ) + + +def _aoti_compile_and_package_inner( + m, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + load_and_run: bool = False, + package_path: Optional[str] = None, + inductor_configs: Optional[Dict[str, Any]] = None, +): + """ + See docstring for aoti_compile_and_package. + + If `load_and_run` is True, this function will load the compiled model and run it. + This is for the minifier to check the correctness of the compiled model. 
+ """ + from torch._inductor.package import package_aoti + + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package"] = True aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type] @@ -102,9 +134,52 @@ def aoti_compile_and_package( res = package_aoti(package_path, aoti_files) assert res == package_path + + if load_and_run: + compiled_model = aoti_load_package(package_path) + aoti_result = compiled_model(*args) return package_path +def aoti_compile_and_package_debug_wrapper( + exported_program, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + package_path: Optional[str] = None, + inductor_configs: Optional[Dict[str, Any]] = None, +): + m = exported_program.module() + assert isinstance(m, torch.fx.GraphModule) + + use_minifier = torch._inductor.config.aot_inductor.dump_aoti_minifier + + try: + return _aoti_compile_and_package_inner( + m, + args, + kwargs, + load_and_run=use_minifier, + package_path=package_path, + inductor_configs=inductor_configs, + ) + + except Exception as e: + if use_minifier: + # TODO: check accuracy and re-direct to minifier + from torch._dynamo.repro.aoti import dump_to_minify + + exported_program._example_inputs = (args, kwargs) + + dump_to_minify( + exported_program, + "compile_fx_aot", + options=inductor_configs, + ) + + raise e + + def aoti_load_package(path: str) -> Any: # type: ignore[type-arg] """ Loads the model from the PT2 package. diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 54cf3d8e8c3f2b..82dc02b862d2eb 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1081,6 +1081,9 @@ class aot_inductor: os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1" ) + # dump an aoti minifier if program errors + dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1" + # Dictionary of presets that can be passed in presets: Dict[str, Any] = {}