diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 842673e5..61150dd2 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -155,7 +155,6 @@ def aval_size_bytes(aval): def compile_ttir_to_ptx_inplace( ttir, - tl_context: tl_ir.Context, cuda_backend: cb.CUDABackend, cuda_options: cb.CUDAOptions, device: int = 0, @@ -165,14 +164,18 @@ def compile_ttir_to_ptx_inplace( if cuda_options.debug: print(ttir) if isinstance(ttir, ir.Module): + context = _triton.ir.context() + _triton.ir.load_dialects(context) + cuda_backend.load_dialects(context) + # Triton compilation APIs only accept Triton-specific MLIR wrappers. # So, here we serialize an ir.Module to a file and then deserialize # it as a tl_ir.module. with tempfile.NamedTemporaryFile(mode="wb") as f: ttir.operation.write_bytecode(f) f.flush() - ttir = tl_ir.parse_mlir_module(f.name, tl_context) - ttir.context = tl_context + ttir = tl_ir.parse_mlir_module(f.name, context) + ttir.context = context try: metadata = dict() opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options) @@ -310,7 +313,6 @@ def get_or_create_triton_kernel( ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = ( compile_ttir_to_ptx_inplace( module, - context, cuda_backend, cuda_options, device=device,