-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformations: (convert-memref-to-ptr) add lower-func flag #3820
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2c35423
transforms: (convert_memref_to_ptr) add func arg rewrite option
kaylendog 82b192f
transforms: (convert_memref_to_ptr) add filecheck tests
kaylendog 5bd9685
feat: add reconcile ptr casts
kaylendog e40ba00
feat: unrealized ptr cast reconciliation
kaylendog aa649c0
fix: fixup return receiver cast
kaylendog 3824049
Merge branch 'xdslproject:main' into main
kaylendog f74d847
tests: update filecheck to include id2
kaylendog f312d13
Merge branch 'main' of https://github.com/kaylendog/xdsl
kaylendog 2f4d5b7
chore: rename flag
kaylendog 450073d
chore: update test
kaylendog d79167a
fix: address PR comments from sasha
kaylendog aed42b5
fix: correct double negation, tuple instantiation
kaylendog File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
42 changes: 42 additions & 0 deletions
42
tests/filecheck/transforms/convert_memref_args_to_ptr.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// RUN: xdsl-opt -p convert-memref-to-ptr{lower_func=true} --split-input-file --verify-diagnostics %s | filecheck %s | ||
|
||
// CHECK: builtin.module { | ||
compor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// CHECK-NEXT: func.func @declaration(!ptr_xdsl.ptr) -> () | ||
func.func @declaration(%arg : memref<2x2xf32>) | ||
|
||
|
||
// CHECK-NEXT: func.func @simple(%arg : !ptr_xdsl.ptr) { | ||
// CHECK-NEXT: func.return | ||
// CHECK-NEXT: } | ||
compor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
func.func @simple(%arg : memref<2x2xf32>) { | ||
func.return | ||
} | ||
|
||
// CHECK-NEXT: func.func @id(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr { | ||
// CHECK-NEXT: func.return %arg : !ptr_xdsl.ptr | ||
// CHECK-NEXT: } | ||
func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> { | ||
func.return %arg : memref<2x2xf32> | ||
} | ||
|
||
// CHECK-NEXT: func.func @id2(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr { | ||
// CHECK-NEXT: %res = func.call @id(%arg) : (!ptr_xdsl.ptr) -> !ptr_xdsl.ptr | ||
// CHECK-NEXT: func.return %res : !ptr_xdsl.ptr | ||
// CHECK-NEXT: } | ||
func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { | ||
%res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> | ||
func.return %res : memref<2x2xf32> | ||
} | ||
|
||
// CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 { | ||
// CHECK-NEXT: %res = ptr_xdsl.load %arg : !ptr_xdsl.ptr -> f32 | ||
// CHECK-NEXT: func.return %res : f32 | ||
// CHECK-NEXT: } | ||
func.func @first(%arg : memref<2x2xf32>) -> f32 { | ||
%pointer = ptr_xdsl.to_ptr %arg : memref<2x2xf32> -> !ptr_xdsl.ptr | ||
%res = ptr_xdsl.load %pointer : !ptr_xdsl.ptr -> f32 | ||
func.return %res : f32 | ||
} | ||
|
||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,8 @@ | |
from typing import cast | ||
|
||
from xdsl.context import MLContext | ||
from xdsl.dialects import arith, builtin, memref, ptr | ||
from xdsl.ir import Operation, SSAValue | ||
from xdsl.dialects import arith, builtin, func, memref, ptr | ||
from xdsl.ir import Attribute, Operation, SSAValue | ||
from xdsl.irdl import Any | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
|
@@ -14,6 +14,7 @@ | |
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.rewriter import InsertPoint | ||
from xdsl.utils.exceptions import DiagnosticException | ||
|
||
|
||
|
@@ -153,12 +154,197 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /): | |
rewriter.replace_matched_op(ops, new_results=[load_result.res]) | ||
|
||
|
||
@dataclass | ||
class LowerMemrefFuncOpPattern(RewritePattern): | ||
""" | ||
Rewrites function arguments of MemRefType to PtrType. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): | ||
# rewrite function declaration | ||
new_input_types = [ | ||
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg | ||
for arg in op.function_type.inputs | ||
] | ||
new_output_types = [ | ||
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg | ||
for arg in op.function_type.outputs | ||
] | ||
op.function_type = func.FunctionType.from_lists( | ||
new_input_types, | ||
new_output_types, | ||
) | ||
|
||
if op.is_declaration: | ||
return | ||
|
||
insert_point = InsertPoint.at_start(op.body.blocks[0]) | ||
|
||
# rewrite arguments | ||
for arg in op.args: | ||
if not isinstance(arg_type := arg.type, memref.MemRefType): | ||
continue | ||
|
||
old_type = cast(memref.MemRefType[Attribute], arg_type) | ||
arg.type = ptr.PtrType() | ||
|
||
if not arg.uses: | ||
continue | ||
|
||
rewriter.insert_op( | ||
cast_op := builtin.UnrealizedConversionCastOp.get([arg], [old_type]), | ||
insert_point, | ||
) | ||
arg.replace_by_if(cast_op.results[0], lambda x: x.operation is not cast_op) | ||
|
||
|
||
@dataclass | ||
class LowerMemrefFuncReturnPattern(RewritePattern): | ||
""" | ||
Rewrites all `memref` arguments to `func.return` into `ptr.PtrType` | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): | ||
if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments): | ||
return | ||
|
||
insert_point = InsertPoint.before(op) | ||
new_arguments: list[SSAValue] = [] | ||
|
||
# insert `memref -> ptr` casts for memref return values | ||
for argument in op.arguments: | ||
if isinstance(argument.type, memref.MemRefType): | ||
rewriter.insert_op( | ||
cast_op := builtin.UnrealizedConversionCastOp.get( | ||
[argument], [ptr.PtrType()] | ||
), | ||
insert_point, | ||
) | ||
new_arguments.append(cast_op.results[0]) | ||
else: | ||
new_arguments.append(argument) | ||
|
||
rewriter.replace_matched_op(func.ReturnOp(*new_arguments)) | ||
|
||
|
||
@dataclass | ||
class LowerMemrefFuncCallPattern(RewritePattern): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would love a docstring here |
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): | ||
if not any( | ||
isinstance(arg.type, memref.MemRefType) for arg in op.arguments | ||
) and not any(isinstance(type, memref.MemRefType) for type in op.result_types): | ||
return | ||
|
||
# rewrite arguments | ||
insert_point = InsertPoint.before(op) | ||
new_arguments: list[SSAValue] = [] | ||
|
||
# insert `memref -> ptr` casts for memref arguments values | ||
for argument in op.arguments: | ||
if isinstance(argument.type, memref.MemRefType): | ||
rewriter.insert_op( | ||
cast_op := builtin.UnrealizedConversionCastOp.get( | ||
[argument], [ptr.PtrType()] | ||
), | ||
insert_point, | ||
) | ||
new_arguments.append(cast_op.results[0]) | ||
else: | ||
new_arguments.append(argument) | ||
|
||
insert_point = InsertPoint.after(op) | ||
new_results: list[SSAValue] = [] | ||
|
||
# insert `ptr -> memref` casts for return values | ||
for result in op.results: | ||
if isinstance(result.type, memref.MemRefType): | ||
rewriter.insert_op( | ||
cast_op := builtin.UnrealizedConversionCastOp.get( | ||
[result], | ||
# TODO: annoying pyright warnings - Sasha, pls help | ||
[result.type], # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType] | ||
), | ||
insert_point, | ||
) | ||
new_results.append(cast_op.results[0]) | ||
else: | ||
new_results.append(result) | ||
|
||
new_return_types = [ | ||
ptr.PtrType() if isinstance(type, memref.MemRefType) else type | ||
for type in op.result_types | ||
] | ||
|
||
rewriter.replace_matched_op( | ||
func.CallOp(op.callee, new_arguments, new_return_types) | ||
) | ||
|
||
|
||
class ReconcileUnrealizedPtrCasts(RewritePattern): | ||
""" | ||
Eliminates two variants of unrealized ptr casts: | ||
- `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`; | ||
- `ptr_xdsl.ptr -> memref.memref` where all uses are `ToPtrOp` operations. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite( | ||
self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, / | ||
): | ||
# preconditions | ||
if ( | ||
len(op.inputs) != 1 | ||
or len(op.outputs) != 1 | ||
or not isinstance(op.inputs[0].type, ptr.PtrType) | ||
or not isinstance(op.outputs[0].type, memref.MemRefType) | ||
): | ||
return | ||
|
||
# erase ptr -> memref -> ptr cast pairs | ||
uses = tuple(use for use in op.outputs[0].uses) | ||
for use in uses: | ||
if ( | ||
isinstance(use.operation, builtin.UnrealizedConversionCastOp) | ||
and isinstance(use.operation.inputs[0].type, memref.MemRefType) | ||
and isinstance(use.operation.outputs[0].type, ptr.PtrType) | ||
): | ||
use.operation.outputs[0].replace_by(op.inputs[0]) | ||
rewriter.erase_op(use.operation) | ||
|
||
# erase this cast entirely if all remaining uses are by ToPtr operations | ||
cast_ops = [use.operation for use in op.outputs[0].uses] | ||
if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops): | ||
return | ||
|
||
for cast_op in cast_ops: | ||
cast_op.results[0].replace_by(op.inputs[0]) | ||
rewriter.erase_op(cast_op) | ||
|
||
rewriter.erase_op(op) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ConvertMemrefToPtr(ModulePass): | ||
name = "convert-memref-to-ptr" | ||
|
||
lower_func: bool = False | ||
|
||
Comment on lines
+333
to
+334
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would also love a docstring here |
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
the_one_pass = PatternRewriteWalker( | ||
PatternRewriteWalker( | ||
GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()]) | ||
) | ||
the_one_pass.rewrite_module(op) | ||
).rewrite_module(op) | ||
|
||
if self.lower_func: | ||
PatternRewriteWalker( | ||
GreedyRewritePatternApplier( | ||
[ | ||
LowerMemrefFuncOpPattern(), | ||
LowerMemrefFuncCallPattern(), | ||
LowerMemrefFuncReturnPattern(), | ||
ReconcileUnrealizedPtrCasts(), | ||
] | ||
) | ||
).rewrite_module(op) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the
--split-input-file
allows to split each case separated byinto a separate file - useful to isolate test behaviour. So, I'd use it in between the
func.func
definitions below.