update torch-mlir mul_cpu.py e2e test and python async backend to work with current torch-mlir (#778)

* update torch-mlir/mul_cpu.py to work with current torch-mlir

* black format
fifield authored Nov 18, 2024
1 parent baf85eb commit 9629b46
Showing 3 changed files with 95 additions and 79 deletions.
python/air/backend/cpu_backend.py (127 changes: 56 additions & 71 deletions)

@@ -5,10 +5,11 @@
 
 import torch
 import torch_mlir.ir
-from torch_mlir.dynamo import make_simple_dynamo_backend
 import torch_mlir.passmanager
+from torch_mlir import torchscript
 
-import air.mlir.ir
-import air.mlir.passmanager
+import air.ir
+import air.passmanager
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
     RefBackendLinalgOnTensorsBackend,
@@ -25,12 +26,14 @@
 from typing import List
 
 path = Path(air.backend.__file__).resolve().parent
-ctypes.CDLL(f"{path}/../../../runtime_lib/aircpu/libaircpu.so", mode=ctypes.RTLD_GLOBAL)
 ctypes.CDLL(
-    f"/FIXME/PATH/TO/llvm/lib/libmlir_async_runtime.so.17git", mode=ctypes.RTLD_GLOBAL
+    f"{path}/../../../runtime_lib/x86_64/aircpu/libaircpu.so", mode=ctypes.RTLD_GLOBAL
+)
+ctypes.CDLL(
+    f"/FIXME/PATH/TO/llvm/lib/libmlir_async_runtime.so.20.0git", mode=ctypes.RTLD_GLOBAL
 )
 
-__all__ = ["AirCpuBackend", "make_dynamo_backend", "DEFAULT_PIPELINE"]
+__all__ = ["AirCpuBackend", "DEFAULT_PIPELINE"]
 
 DEFAULT_PIPELINE = (
     "builtin.module(" + ",".join(["air-to-async", "canonicalize", "cse"]) + ")"
@@ -52,6 +55,7 @@
     + ")"
 )
 
+# copied from torch-mlir
 REF_BACKEND_LOWERING_PIPELINE = (
     "builtin.module("
     + ",".join(
@@ -91,7 +95,6 @@
             "func.func(convert-math-to-llvm)",
             # Handle some complex mlir::math ops (e.g. atan2)
             "convert-math-to-libm",
-            "convert-linalg-to-llvm",
            "expand-strided-metadata",
             "finalize-memref-to-llvm",
             "lower-affine",
@@ -123,7 +126,7 @@ def __del__(self):
 
     def compile(
         self,
-        air_module: air.mlir.ir.Module,
+        air_module: air.ir.Module,
         pipeline=None,
         verbose=False,
         segment_offset=None,
@@ -133,7 +136,7 @@ def compile(
         The module is expected to be AIR dialect.
         Args:
-            imported_module: The MLIR module consisting of functions containing
+            air_module: The MLIR module consisting of functions containing
                 AIR dialect.
             pipeline: The custom lowering pipeline to use for lowering.
                 The default is `air.backend.cpu_backend.DEFAULT_PIPELINE`
@@ -148,21 +151,25 @@
             pipeline = DEFAULT_PIPELINE
 
         s = str(air_module)
-        with air_module.context:
+        with air.ir.Context() as ctx:
+            ctx.allow_unregistered_dialects = True
             # make a copy of the input MLIR
-            air_module = air.mlir.ir.Module.parse(s)
+            air_module = air.ir.Module.parse(s)
 
             if verbose:
                 print("Running MLIR pass pipeline: ", pipeline)
 
-            pm = air.mlir.passmanager.PassManager.parse(pipeline)
-            pm.run(air_module.operation)
+            if callable(pipeline):
+                air_module = pipeline(air_module)
+            else:
+                pm = air.passmanager.PassManager.parse(pipeline)
+                pm.run(air_module.operation)
 
             if verbose:
                 print("Async Module:")
                 print(air_module)
 
-            pm = air.mlir.passmanager.PassManager.parse(ASYNC_TO_LLVM_PIPELINE)
+            pm = air.passmanager.PassManager.parse(ASYNC_TO_LLVM_PIPELINE)
             pm.run(air_module.operation)
 
             if verbose:
@@ -175,71 +182,49 @@ def compile(
             pm.run(torch_mlir_module.operation)
         return torch_mlir_module
 
-    def load(self, module):
-        """Load a compiled artifact."""
-        return self.refbackend.load(module)
-
-    def unload(self):
-        """Unload any loaded module and release resources."""
-        pass
-
-
-def make_dynamo_backend(pipeline=None, verbose=False):
-    """Make a PyTorch dynamo backend using AirCpuBackend.
-    Args:
-        pipeline: The custom lowering pipeline to use for lowering. First
-            `air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE` is applied,
-            then `pipeline`.
-            The default is `air.backend.linalg_on_tensors.LINALG_MEMREF_TO_AIR_PIPELINE`
-        verbose: enable verbose output
-        segment_offset: default location for generated segments as [colOffset, rowOffset]
-        segment_size: default size for generated segments as [numCols, numRows]
-    Returns:
-        A PyTorch dynamo backend
-    """
-    backend = AirCpuBackend()
+    def compile_from_torch_mlir(
+        self,
+        imported_module: torch_mlir.ir.Module,
+        pipeline=None,
+        verbose=False,
+        segment_offset=None,
+        segment_size=None,
+    ):
+        if type(imported_module) is torch_mlir.ir.Module:
+            with imported_module.operation.context:
+                imported_module = torchscript.lower_mlir_module(
+                    False, torchscript.OutputType.LINALG_ON_TENSORS, imported_module
+                )
 
-    @make_simple_dynamo_backend
-    def air_backend(fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
+        pm = torch_mlir.passmanager.PassManager.parse(
+            "builtin.module(refback-mlprogram-bufferize)"
+        )
+        pm.run(imported_module.operation)
 
-        # get the linalg mlir of the model from torch_mlir
-        mlir_module = torch_mlir.compile(
-            fx_graph,
-            example_inputs,
-            output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
-        )
+        if verbose:
+            print("Torch Module:")
+            print(imported_module)
 
-        with air.mlir.ir.Context():
-            air_module = air.mlir.ir.Module.parse(str(mlir_module))
-            pm = air.mlir.passmanager.PassManager.parse(
+        with air.ir.Context():
+            air_module = air.ir.Module.parse(str(imported_module))
+            pm = air.passmanager.PassManager.parse(
                 air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE
             )
-            pm.run(air_module.operation)
-            if pipeline is None:
-                pm = air.mlir.passmanager.PassManager.parse(
-                    linalg_on_tensors.LINALG_MEMREF_TO_AIR_PIPELINE
-                )
-                pm.run(air_module.operation)
-            elif callable(pipeline):
-                air_module = pipeline(air_module)
-            else:
-                pm = air.mlir.passmanager.PassManager.parse(pipeline)
-                pm.run(air_module.operation)
 
             if verbose:
-                print("AIR Module:")
-                print(air_module)
+                print(
+                    "Running MLIR pass pipeline: ",
+                    air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE,
+                )
 
-            compiled = backend.compile(air_module, verbose=verbose)
+            pm.run(air_module.operation)
 
-        # return a function for invoking the compiled model
-        def compiled_callable(*inputs):
-            inputs = [x.numpy() for x in inputs]
-            loaded = backend.load(compiled)
-            result = loaded.forward(*inputs)
-            return torch.from_numpy(result)
+        return self.compile(air_module, pipeline, verbose, segment_offset, segment_size)
 
-        return compiled_callable
+    def load(self, module):
+        """Load a compiled artifact."""
+        return self.refbackend.load(module)
 
-    return air_backend
+    def unload(self):
+        """Unload any loaded module and release resources."""
+        pass
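
Note on the new contract: after this change, the pipeline argument to AirCpuBackend.compile() (and to compile_from_torch_mlir(), which forwards it) may be either a pass-pipeline string or a Python callable that transforms the module. A minimal sketch of both forms, assuming an existing air.ir.Module named air_module already in AIR dialect; the callable below is illustrative only, not part of the commit:

    import air.ir
    import air.passmanager
    from air.backend.cpu_backend import AirCpuBackend, DEFAULT_PIPELINE

    backend = AirCpuBackend()

    # Form 1: a pipeline string; compile() parses and runs it with a PassManager.
    compiled = backend.compile(air_module, pipeline=DEFAULT_PIPELINE)

    # Form 2: a callable; compile() replaces the module with the callable's result.
    def my_pipeline(module):
        pm = air.passmanager.PassManager.parse(DEFAULT_PIPELINE)
        pm.run(module.operation)
        return module

    compiled = backend.compile(air_module, pipeline=my_pipeline)
    program = backend.load(compiled)  # the loaded artifact exposes forward()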
python/air/backend/linalg_on_tensors.py (5 changes: 4 additions & 1 deletion)

@@ -46,7 +46,10 @@
     print("[WARNING] We were not able to load .so for libairhost_shared.so")
     print(e)
     pass
-import air._mlir_libs._airRt as airrt
+try:
+    import air._mlir_libs._airRt as airrt
+except Exception as e:
+    pass
 
 __all__ = [
     "LinalgOnTensorsAirBackend",
python/test/torch_mlir_e2e/mul_cpu.py (42 changes: 35 additions & 7 deletions)

@@ -3,15 +3,16 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc.
 # SPDX-License-Identifier: MIT
 
-# REQUIRES: torch_mlir, needs_update
+# REQUIRES: torch_mlir
 
 # RUN: %PYTHON %s | FileCheck %s
-# CHECK: PASSED
+# CHECK: PASSED! 8/8
 
 import torch
-import torch._dynamo as dynamo
+from torch_mlir import fx
 
-from air.backend import cpu_backend as backend
+from air.backend import cpu_backend
+from air.passmanager import PassManager
 
 verbose = False
 
@@ -25,17 +26,44 @@ def forward(self, a, b):
         return x
 
 
-air_backend = backend.make_dynamo_backend(verbose=verbose)
+def pipeline(module):
+    with module.operation.context as ctx:
+        pipeline = (
+            "builtin.module("
+            + ",".join(
+                [
+                    "canonicalize",
+                    "cse",
+                    "air-linalg-codegen",
+                    "air-par-to-herd{depth=0}",
+                    "air-copy-to-dma",
+                    "canonicalize",
+                    "cse",
+                ]
+            )
+            + ")"
+        )
+        pm = PassManager.parse(pipeline)
+        pm.run(module.operation)
+        pm = PassManager.parse(cpu_backend.DEFAULT_PIPELINE)
+        pm.run(module.operation)
+    return module
 
 
 def run_test(model, dtype, shape):
     torch_model = model()
-    dynamo_model = dynamo.optimize(air_backend)(torch_model)
 
     a = torch.randint(size=shape, low=1, high=100, dtype=dtype)
     b = torch.randint(size=shape, low=1, high=100, dtype=dtype)
-    c = dynamo_model(a, b)
+    m = fx.export_and_import(torch_model, a, b, func_name="forward")
+
+    backend = cpu_backend.AirCpuBackend()
+    air_program = backend.load(
+        backend.compile_from_torch_mlir(m, pipeline=pipeline, verbose=verbose)
+    )
 
     c_ref = torch_model(a, b)
+    c = torch.tensor(air_program.forward(a.numpy(), b.numpy()))
 
     if verbose:
         print(f"input:\n{a}\n{b}\noutput:\n{c}")
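
The dtype/shape combinations that produce the eight PASSED checks live in the truncated remainder of the test file. A hypothetical single invocation mirroring the run_test signature above, where MulModel stands in for the test's actual model class (not shown in this hunk):

    # Hypothetical usage; MulModel is a stand-in name, not from this commit.
    run_test(MulModel, torch.int32, (128, 128))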
