diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 85e614fc6a..898d770067 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -264,7 +264,7 @@ def __init__(self, options: Optional[CompileOptions] = None): self.options = options if options is not None else CompileOptions() @debug_logger - def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspace): + def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspace, *args): """Prepare the command for catalyst-cli to compile the file. Args: @@ -290,6 +290,9 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac cmd += ["--verbose"] if self.options.checkpoint_stage: cmd += ["--checkpoint-stage", self.options.checkpoint_stage] + if args: + for arg in args: + cmd += [str(arg)] pipeline_str = "" for pipeline in self.options.get_pipelines(): @@ -301,7 +304,7 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac return cmd @debug_logger - def run_from_ir(self, ir: str, module_name: str, workspace: Directory): + def run_from_ir(self, ir: str, module_name: str, workspace: Directory, *args): """Compile a shared object from a textual IR (MLIR or LLVM). Args: @@ -335,7 +338,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): output_object_name = os.path.join(str(workspace), f"{module_name}.o") output_ir_name = os.path.join(str(workspace), f"{module_name}{output_ir_ext}") - cmd = self.get_cli_command(tmp_infile_name, output_ir_name, module_name, workspace) + cmd = self.get_cli_command(tmp_infile_name, output_ir_name, module_name, workspace, *args) try: if self.options.verbose: print(f"[SYSTEM] {' '.join(cmd)}", file=self.options.logfile) diff --git a/frontend/catalyst/debug/compiler_functions.py b/frontend/catalyst/debug/compiler_functions.py index 2fa16b4218..75eeed818a 100644 --- a/frontend/catalyst/debug/compiler_functions.py +++ b/frontend/catalyst/debug/compiler_functions.py @@ -15,6 +15,9 @@ """ This module contains debug functions to interact with the compiler and compiled functions. """ +import dis +import functools +import inspect import logging import os import platform @@ -22,6 +25,8 @@ import shutil import subprocess +from jax._src.interpreters import mlir + import catalyst from catalyst.compiler import LinkerDriver from catalyst.logging import debug_logger @@ -183,6 +188,66 @@ def replace_ir(fn, stage, new_ir): fn.fn_cache.clear() +def get_docstring_ir(kallable, opts): + # This is necessary because of recursive imports... + class DocstringIR(catalyst.QJIT): + + def __init__(self, kallable, compile_options): + assert callable(kallable) + + docstring = inspect.getdoc(kallable) + assert docstring + + bytecode = dis.get_instructions(kallable) + + @functools.wraps(kallable) + def noop(*args, **kwargs): + """Dummy docstring""" + + super().__init__(noop, compile_options) + + noop_bytecode = dis.get_instructions(noop) + is_applicable = all( + itarget.opcode == itruth.opcode for itarget, itruth in zip(bytecode, noop_bytecode) + ) + assert is_applicable + + replace_ir(self, "mlir", docstring) + + options = catalyst.compiler.CompileOptions() + options.pipelines = [("Generic", ["builtin.module(canonicalize)"])] + options.lower_to_llvm = False + + print_generic = catalyst.compiler.Compiler(options) + _, generic_ir = print_generic.run_from_ir( + docstring, "generic", self.workspace, "--mlir-print-op-generic" + ) + + ctx = mlir.make_ir_context() + ctx.allow_unregistered_dialects = True + + self.aot_compile() + self.mlir_module = mlir.ir.Module.parse(generic_ir, context=ctx) + self.mlir = docstring + self.fn_cache.clear() + self.compile_options.keep_intermediate = True + self.compiled_function, self.qir = self.compile() + self.fn_cache.insert( + self.compiled_function, self.user_sig, self.out_treedef, self.workspace + ) + + return DocstringIR(kallable, opts) + + +@debug_logger +def docstringir(kallable): + """Decorator that denotes that a function's docstring is + actually IR.""" + + options = catalyst.compiler.CompileOptions() + return foo(kallable, options) + + @debug_logger def compile_executable(fn, *args): """Generate an executable binary for the native host architecture from a diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 86722edd56..18e456039b 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -543,6 +543,11 @@ def aot_compile(self): """Compile Python function on initialization using the type hint signature.""" self.workspace = self._get_workspace() + return_annotation = inspect.signature(self.original_function).return_annotation + if return_annotation and return_annotation != inspect.Signature.empty: + return_annotation = get_abstract_signature(return_annotation) + else: + return_annotation = None # TODO: awkward, refactor or redesign the target feature if self.compile_options.target in ("jaxpr", "mlir", "binary"): @@ -554,6 +559,10 @@ def aot_compile(self): self.user_sig or () ) + if return_annotation: + vals, self.out_treedef = tree_flatten(return_annotation) + self.out_type = [(True, retty) for retty in vals] + if self.compile_options.target in ("mlir", "binary"): self.mlir_module, self.mlir = self.generate_ir() diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 46efd76162..458d9a2f57 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -143,7 +143,7 @@ def get_stages(self): """Returns all stages in order for compilation""" # Dictionaries in python are ordered stages = {} - stages["EnforeRuntimeInvariantsPass"] = get_enforce_runtime_invariants_stage(self) + stages["EnforceRuntimeInvariantsPass"] = get_enforce_runtime_invariants_stage(self) stages["HLOLoweringPass"] = get_hlo_lowering_stage(self) stages["QuantumCompilationPass"] = get_quantum_compilation_stage(self) stages["BufferizationPass"] = get_bufferization_stage(self)