Skip to content

Commit

Permalink
dialects: (linalg) let ConvOpsBase inherit from NamedOpsBase (#3841)
Browse files Browse the repository at this point in the history
This enables correct printing of the hidden regions in generic format

Resolves conv ops in #2959, will be tested with #3837
  • Loading branch information
jorendumoulin authored and emmau678 committed Feb 6, 2025
1 parent a216535 commit 41b16cd
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
%18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
%20 = "test.op"() : () -> (memref<64x4096xf32>)

%zero = arith.constant 0: f32
%zero = arith.constant 0.0 : f32
linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>)

linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)
Expand All @@ -72,6 +72,14 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou

%quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32>

%26, %27, %28 = "test.op"(): () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>)

%conv_2d_nchw_i = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins(%26, %27: tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>)
outs(%28: tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32>



// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()>
// CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: module {
Expand Down Expand Up @@ -117,4 +125,6 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
// CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: %21:3 = "test.op"() : () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>)
// CHECK-NEXT: %22 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%21#0, %21#1 : tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>) outs(%21#2 : tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32>
// CHECK-NEXT: }
6 changes: 4 additions & 2 deletions tests/interpreters/test_linalg_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,16 @@ def test_linalg_conv_2d_nchw_fchw():
interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(LinalgFunctions())
op = linalg.Conv2DNchwFchwOp(
DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
(
TestSSAValue(TensorType(f32, [1, 1, 5, 5])),
TestSSAValue(TensorType(f32, [1, 1, 3, 3])),
),
(TestSSAValue(TensorType(f32, [1, 1, 3, 3])),),
(TensorType(f32, [1, 1, 3, 3]),),
{
"dilations": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
"strides": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
},
)
a = ShapedArray(TypedPtr.new_float32(list(range(25))), [1, 1, 5, 5])
b = ShapedArray(
Expand Down
52 changes: 28 additions & 24 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,43 +1004,47 @@ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
)


class ConvOpsBase(IRDLOperation, ABC):
class ConvOpsBase(NamedOpBase, ABC):
"""Base class for linalg convolution operations."""

inputs = var_operand_def()
outputs = var_operand_def(base(ShapedType))

res = var_result_def(AnyTensorType)

assembly_format = (
"attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` "
"`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
)
PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

def __init__(
self,
dilations: Attribute,
strides: Attribute,
inputs: Sequence[SSAValue],
outputs: Sequence[SSAValue] = (),
res: Sequence[Attribute] | None = None,
attributes: dict[str, Attribute] | None = None,
):
if res is None:
result_types = tuple(output.type for output in outputs)
else:
result_types = res
arg_types = self.body_arg_types((*inputs, *outputs))
add, mul = (
(arith.AddfOp, arith.MulfOp)
if isinstance(arg_types[-1], AnyFloat)
else (arith.AddiOp, arith.MuliOp)
)

@Builder.implicit_region(arg_types)
def hidden_region(args: tuple[BlockArgument, ...]) -> None:
if arg_types[0] != arg_types[-1]:
assert isinstance(arg_types[-1], IntegerType)
a = arith.ExtSIOp(args[0], arg_types[-1])
b = arith.ExtSIOp(args[1], arg_types[-1])
else:
a = args[0]
b = args[1]
result = mul(a, b)
mac = add(result, args[2])
YieldOp(mac)

super().__init__(
attributes={
"dilations": dilations,
"strides": strides,
},
operands=(inputs, outputs),
result_types=result_types,
ins=inputs,
outs=outputs,
attributes=attributes,
result_types=res,
hidden_region=hidden_region,
)


Expand Down

0 comments on commit 41b16cd

Please sign in to comment.