Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Compile IR from docstring. #1334

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(self, options: Optional[CompileOptions] = None):
self.options = options if options is not None else CompileOptions()

@debug_logger
def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspace):
def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspace, *args):
"""Prepare the command for catalyst-cli to compile the file.

Args:
Expand All @@ -290,6 +290,9 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac
cmd += ["--verbose"]
if self.options.checkpoint_stage:
cmd += ["--checkpoint-stage", self.options.checkpoint_stage]
if args:
for arg in args:
cmd += [str(arg)]

pipeline_str = ""
for pipeline in self.options.get_pipelines():
Expand All @@ -301,7 +304,7 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac
return cmd

@debug_logger
def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
def run_from_ir(self, ir: str, module_name: str, workspace: Directory, *args):
"""Compile a shared object from a textual IR (MLIR or LLVM).

Args:
Expand Down Expand Up @@ -335,7 +338,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
output_object_name = os.path.join(str(workspace), f"{module_name}.o")
output_ir_name = os.path.join(str(workspace), f"{module_name}{output_ir_ext}")

cmd = self.get_cli_command(tmp_infile_name, output_ir_name, module_name, workspace)
cmd = self.get_cli_command(tmp_infile_name, output_ir_name, module_name, workspace, *args)
try:
if self.options.verbose:
print(f"[SYSTEM] {' '.join(cmd)}", file=self.options.logfile)
Expand Down
65 changes: 65 additions & 0 deletions frontend/catalyst/debug/compiler_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
"""
This module contains debug functions to interact with the compiler and compiled functions.
"""
import dis
import functools
import inspect
import logging
import os
import platform
import re
import shutil
import subprocess

from jax._src.interpreters import mlir

import catalyst
from catalyst.compiler import LinkerDriver
from catalyst.logging import debug_logger
Expand Down Expand Up @@ -183,6 +188,66 @@
fn.fn_cache.clear()


def get_docstring_ir(kallable, opts):

Check notice on line 191 in frontend/catalyst/debug/compiler_functions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/debug/compiler_functions.py#L191

Missing function or method docstring (missing-function-docstring)
# This is necessary because of recursive imports...
class DocstringIR(catalyst.QJIT):

Check notice on line 193 in frontend/catalyst/debug/compiler_functions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/debug/compiler_functions.py#L193

Missing class docstring (missing-class-docstring)

def __init__(self, kallable, compile_options):
assert callable(kallable)

docstring = inspect.getdoc(kallable)
assert docstring

bytecode = dis.get_instructions(kallable)

@functools.wraps(kallable)
def noop(*args, **kwargs):

Check notice on line 204 in frontend/catalyst/debug/compiler_functions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/debug/compiler_functions.py#L204

Unused argument 'args' (unused-argument)

Check notice on line 204 in frontend/catalyst/debug/compiler_functions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/debug/compiler_functions.py#L204

Unused argument 'kwargs' (unused-argument)
"""Dummy docstring"""

super().__init__(noop, compile_options)

noop_bytecode = dis.get_instructions(noop)
is_applicable = all(
itarget.opcode == itruth.opcode for itarget, itruth in zip(bytecode, noop_bytecode)
)
assert is_applicable

replace_ir(self, "mlir", docstring)

options = catalyst.compiler.CompileOptions()
options.pipelines = [("Generic", ["builtin.module(canonicalize)"])]
options.lower_to_llvm = False

print_generic = catalyst.compiler.Compiler(options)
_, generic_ir = print_generic.run_from_ir(
docstring, "generic", self.workspace, "--mlir-print-op-generic"
)

ctx = mlir.make_ir_context()
ctx.allow_unregistered_dialects = True

self.aot_compile()
self.mlir_module = mlir.ir.Module.parse(generic_ir, context=ctx)
self.mlir = docstring
self.fn_cache.clear()
self.compile_options.keep_intermediate = True
self.compiled_function, self.qir = self.compile()
self.fn_cache.insert(
self.compiled_function, self.user_sig, self.out_treedef, self.workspace
)

return DocstringIR(kallable, opts)


@debug_logger
def docstringir(kallable):
"""Decorator that denotes that a function's docstring is
actually IR."""

options = catalyst.compiler.CompileOptions()
return foo(kallable, options)

Check notice on line 248 in frontend/catalyst/debug/compiler_functions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/debug/compiler_functions.py#L248

Undefined variable 'foo' (undefined-variable)


@debug_logger
def compile_executable(fn, *args):
"""Generate an executable binary for the native host architecture from a
Expand Down
9 changes: 9 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,11 @@ def aot_compile(self):
"""Compile Python function on initialization using the type hint signature."""

self.workspace = self._get_workspace()
return_annotation = inspect.signature(self.original_function).return_annotation
if return_annotation and return_annotation != inspect.Signature.empty:
return_annotation = get_abstract_signature(return_annotation)
else:
return_annotation = None

# TODO: awkward, refactor or redesign the target feature
if self.compile_options.target in ("jaxpr", "mlir", "binary"):
Expand All @@ -554,6 +559,10 @@ def aot_compile(self):
self.user_sig or ()
)

if return_annotation:
vals, self.out_treedef = tree_flatten(return_annotation)
self.out_type = [(True, retty) for retty in vals]

if self.compile_options.target in ("mlir", "binary"):
self.mlir_module, self.mlir = self.generate_ir()

Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_stages(self):
"""Returns all stages in order for compilation"""
# Dictionaries in python are ordered
stages = {}
stages["EnforeRuntimeInvariantsPass"] = get_enforce_runtime_invariants_stage(self)
stages["EnforceRuntimeInvariantsPass"] = get_enforce_runtime_invariants_stage(self)
stages["HLOLoweringPass"] = get_hlo_lowering_stage(self)
stages["QuantumCompilationPass"] = get_quantum_compilation_stage(self)
stages["BufferizationPass"] = get_bufferization_stage(self)
Expand Down
Loading