From ff256d53ba6ba93fe68a6e17800485b0d2c6ad37 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Wed, 5 Feb 2025 09:51:27 +0100 Subject: [PATCH] dialects: (linalg) add hidden region to conv ops --- .../with-mlir/dialects/linalg/ops.mlir | 12 ++++- tests/interpreters/test_linalg_interpreter.py | 6 ++- xdsl/dialects/linalg.py | 52 ++++++++++--------- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 9d942b95d9..6b788a6e3e 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -59,7 +59,7 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) -%zero = arith.constant 0: f32 +%zero = arith.constant 0.0 : f32 linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -72,6 +72,14 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou %quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +%26, %27, %28 = "test.op"(): () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>) + +%conv_2d_nchw_i = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%26, %27: tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>) + outs(%28: tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32> + + + // CHECK-NEXT: #map = affine_map<(d0, d1) -> ()> // CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: module { @@ -117,4 +125,6 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> // CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %21:3 = "test.op"() : () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>) +// CHECK-NEXT: %22 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%21#0, %21#1 : tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>) outs(%21#2 : tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32> // CHECK-NEXT: } diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py index 542958d669..c19eea2417 100644 --- a/tests/interpreters/test_linalg_interpreter.py +++ b/tests/interpreters/test_linalg_interpreter.py @@ -337,14 +337,16 @@ def test_linalg_conv_2d_nchw_fchw(): interpreter = Interpreter(ModuleOp([])) interpreter.register_implementations(LinalgFunctions()) op = linalg.Conv2DNchwFchwOp( - DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]), - DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]), ( TestSSAValue(TensorType(f32, [1, 1, 5, 5])), TestSSAValue(TensorType(f32, [1, 1, 3, 3])), ), (TestSSAValue(TensorType(f32, [1, 1, 3, 3])),), (TensorType(f32, [1, 1, 3, 3]),), + { + "dilations": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]), + "strides": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]), + }, ) a = ShapedArray(TypedPtr.new_float32(list(range(25))), [1, 1, 5, 5]) b = ShapedArray( diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 9d3d4a6d87..51352294ad 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -1000,43 +1000,47 @@ class PoolingNchwMaxOp(PoolingOpsBase): name = "linalg.pooling_nchw_max" -class ConvOpsBase(IRDLOperation, ABC): +class ConvOpsBase(NamedOpBase, ABC): """Base class for linalg convolution operations.""" - inputs = var_operand_def() - outputs = var_operand_def(base(ShapedType)) - - res = var_result_def(AnyTensorType) - - assembly_format = ( - "attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` " - "`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)" - ) + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True strides = attr_def(DenseIntOrFPElementsAttr) dilations = attr_def(DenseIntOrFPElementsAttr) - irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] - def __init__( self, - dilations: Attribute, - strides: Attribute, inputs: Sequence[SSAValue], outputs: Sequence[SSAValue] = (), res: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, ): - if res is None: - result_types = tuple(output.type for output in outputs) - else: - result_types = res + arg_types = self.body_arg_types((*inputs, *outputs)) + add, mul = ( + (arith.AddfOp, arith.MulfOp) + if isinstance(arg_types[-1], AnyFloat) + else (arith.AddiOp, arith.MuliOp) + ) + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + if arg_types[0] != arg_types[-1]: + assert isinstance(arg_types[-1], IntegerType) + a = arith.ExtSIOp(args[0], arg_types[-1]) + b = arith.ExtSIOp(args[1], arg_types[-1]) + else: + a = args[0] + b = args[1] + result = mul(a, b) + mac = add(result, args[2]) + YieldOp(mac) + super().__init__( - attributes={ - "dilations": dilations, - "strides": strides, - }, - operands=(inputs, outputs), - result_types=result_types, + ins=inputs, + outs=outputs, + attributes=attributes, + result_types=res, + hidden_region=hidden_region, )