Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Linalg] Introduce linalg.contract #123618

Merged
merged 11 commits into from
Jan 29, 2025
Merged

Conversation

rolfmorel
Copy link
Contributor

A new op that allows for representing arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and arbitrary broadcasts specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to vector.contract.

Corresponding RFC is here: https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589

@llvmbot
Copy link
Member

llvmbot commented Jan 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

Changes

A new op that allows for representing arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and arbitrary broadcasts specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to vector.contract.

Corresponding RFC is here: https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589


Patch is 46.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123618.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+118)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+198-26)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+165-3)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir (+50)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+97)
  • (modified) mlir/test/Dialect/Linalg/loops.mlir (+47)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+21-2)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+14-3)
  • (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+46)
  • (modified) mlir/test/Dialect/Linalg/transform-op-vectorize.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating on top of a third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting them to the same data type as the accumulator/output.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depends on attr associated to current Op instance.
+    bool hasDynamicIndexingMaps() { return true; };
+    bool hasUserDefinedMaps() { return true; };
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..355ed2d269291c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3528,44 +3528,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
+FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+  if (parser.parseOptionalKeyword("indexing_maps"))
+    return {nullptr}; // Success in case indexing_maps was not provided.
 
-    if (parser.parseLSquare())
+  SmallVector<Attribute> indexingMaps;
+
+  auto parseIndexingMap = [&]() -> ParseResult {
+    AffineMapAttr affineMapAttr;
+    if (parser.parseAttribute(affineMapAttr))
       return failure();
+    indexingMaps.push_back(affineMapAttr);
+    return success();
+  };
 
-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
+  if (parser.parseEqual() ||
+      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                     parseIndexingMap))
+    return failure();
 
-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
+  return parser.getBuilder().getArrayAttr(indexingMaps);
+}
 
-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure();
+
+  if (*indexingMapsAttr == nullptr) {
+    auto indexingMapAttrs = llvm::map_to_vector(
         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
   }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
 
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3599,6 +3600,7 @@ LogicalResult MatmulOp::verify() {
 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
+
 void MatmulOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  /// domains are all the same, and each implements a projected permutation.
+  /// Each dim in the domain must occur for at least one operand and is
+  /// classified as either batch, N-like, M-like, or K-like. Only the latter
+  /// corresponds to a reduction _and_ it is the only dim-kind which does not
+  /// occur for the output operand. We use this fact for fast inference:
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_map' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and co-domain of "
+                         "corresponding affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+
+    for (AffineExpr affineExpr : affineMap.getResults()) {
+      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+      if (!affineDimExpr)
+        llvm_unreachable("affine_map is a projected permutation");
+
+      if (isInput)
+        inOccurrences[affineDimExpr.getPosition()] += 1;
+      else
+        outOccurrences[affineDimExpr.getPosition()] += 1;
+    }
+
+    return success();
+  };
+
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false}))
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NOTE: checking lambda will emit error.
+
+  bool hasContractingDim = false;
+  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+    if (inOccCount == 0)
+      return emitError("iteration space dim not used by either input");
+
+    // NB: A dim which occurs for only one input operand and not for the output.
+    //     In terms of einsum semantics, such dims have a sensible meaning -
+    //     namely an additional reduction per such dim - though this can also
+    //     always be expressed through an additional op. Additionally, at time
+    //     of writing, vector.contract's verifier accepts these dims but many of
+    //     its lowerings do not handle these kinds of dims. Hence...
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError("iter type of dim is not one of M, N, K or batch");
+  }
+
+  if (!hasContractingDim)
+    return emitError("'indexing_maps' do not specify a contracting dimension");
+
+  return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>,
+                    affine_map<(d0, d1, d2) -> (d1, d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHEC...
[truncated]

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the small scoped change. Much easier to review. I just have one question.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. No blockers for me. Please wait for others to take a look.

I would also highly recommend trying to make the parser/printer work with assemblyFormat. It will make things much better going forward.

@MaheshRavishankar MaheshRavishankar dismissed their stale review January 21, 2025 00:09

My main concern was addressed

@rolfmorel
Copy link
Contributor Author

Sorry @MaheshRavishankar, my initial take on your comment regarding regions was wrong. Have updated my comment up above.

Please let me know if that still addresses your concern.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dismissed prematurely. Just want to get clarification on the region of the op, but otherwise this looks good.

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really good 👍
Just a few minor nits

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
@rolfmorel
Copy link
Contributor Author

Thanks for looking over the PR, @adam-smnk ! I have now addressed your remarks - let me know what you think.

@MaheshRavishankar MaheshRavishankar dismissed their stale review January 22, 2025 02:40

Dismissing since the region handling will be done as a follow up

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
@rengolin rengolin requested a review from banach-space January 27, 2025 09:34
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff, thanks!

Mostly drive-by, though I do have one high level question. What about the scaling factor that is sometimes used for the accumulator matrix in GEMM implementations? Sorry if this was mentioned somewhere!

Also, parsing this on my tablet, so bear with me 😅

mlir/test/Dialect/Linalg/generalize-named-ops.mlir Outdated Show resolved Hide resolved
mlir/test/Dialect/Linalg/generalize-named-ops.mlir Outdated Show resolved Hide resolved
mlir/test/Dialect/Linalg/generalize-named-ops.mlir Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
mlir/test/Dialect/Linalg/transform-op-vectorize.mlir Outdated Show resolved Hide resolved
@rolfmorel
Copy link
Contributor Author

rolfmorel commented Jan 27, 2025

Thanks for this round of comments, @banach-space ! I hope I have addressed all of the them now.

What about the scaling factor that is sometimes used for the accumulator matrix in GEMM implementations? Sorry if this was mentioned somewhere!

I don't think anybody brought it up before! If I am not mistaken though it suffers from the same issue as truncation (though in reverse): because we are dealing with a contraction, the op's associated body will run multiple times. How should we denote - in terms of the perfect loop nests correspondence for "structured ops" - that only on the first run of the body we should apply the scaling to C's elements? Phrased alternatively: what linalg.generic-form did you have in mind?

@rengolin
Copy link
Member

Thanks for this round of comments, @banach-space ! I hope I have addressed all of the them now.

What about the scaling factor that is sometimes used for the accumulator matrix in GEMM implementations? Sorry if this was mentioned somewhere!

I don't think anybody brought it up before! If I am not mistaken though it suffers from the same issue as truncation (though in reverse): because we are dealing with a contraction, the op's associated body will run multiple times. How should we denote - in terms of the perfect loop nests correspondence for "structured ops" - that only on the first run of the body we should apply the scaling to C's elements? Phrased alternatively: what linalg.generic-form did you have in mind?

If by "scaling factor" you mean Micro-Scaling (https://arxiv.org/pdf/2310.10537), then this will be added later across all linalg ops, or possibly as a separate type dialect (https://github.com/libxsmm/tpp-mlir/wiki/Microscaling-Data-Formats-in-MLIR). I'm not sure yet how this will work.

Currently, IIUC, MXFP implementations just separate the scaling from the matmul ops and tile them with a custom algorithm.

@banach-space
Copy link
Contributor

How should we denote - in terms of the perfect loop nests correspondence for "structured ops" - that only on the first run of the body we should apply the scaling to C's elements?

Ah, now I see what the problem is. Yeah, I don't see how that could be represented with a single GenericOp 🤔

If by "scaling factor" you mean Micro-Scaling (https://arxiv.org/pdf/2310.10537), then this will be added later across all linalg ops,

Oh, that's way ahead of what I was thinking 😅 But also important one to keep in mind.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the formatting updates, that's much appreciated! 🙏🏻

I've skimmed through some other bits and made small suggestions. All in all LG, in line with the RFC.

Btw, thank you for all the diligence and enormous effort that has gone into figuring this out! This will have quite an impact on Linalg, so no pressure 😂

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
mlir/test/Dialect/Linalg/invalid.mlir Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
@rolfmorel
Copy link
Contributor Author

Thank you for the thorough review, @banach-space !

I believe everything is addressed now. You could maybe have a look at the expanded docs and see if you are happy with what they say on broadcasting. Otherwise I think we are close to landing this PR.

Thanks again!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the updates to the docs, all that info will be super helpful! The stuff on bcast dims is now super clear 🙏🏻

I've left a couple of minor comments/nits, but these are non-blockers. I won't have any further comments so approving as is.

LGTM, great work!

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
@rolfmorel
Copy link
Contributor Author

Thanks @banach-space ! On your prompting, I have gone over the docs and comments once more and cleaned them up further. My thanks in general to the reviewers - it has improved the PR considerably!

As I am satisfied with the state of the PR, and I believe the others who commented are happy as well, I will merge the PR in a couple of hours (after a rebase and check-all). I will be on the ball when it comes to doing any potential fix-ups and PRs for minor extensions.

Thanks again to all those who participated in the RFC and on the PR!

A new op that allows for representing arbitrary contractions on operands
of arbitrary rank, with arbitrary transposes and arbitrary broadcasts
specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to vector.contract.
This reverts commit 3d0d5b3.

Actually, that line is needed...
@rolfmorel rolfmorel merged commit 0d4efa2 into llvm:main Jan 29, 2025
5 of 6 checks passed
@erichkeane
Copy link
Collaborator

I have clang-10 on my machine, and this doesn't compile!

See:

/local/home/ekeane/llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:3536:12: error: call to constructor of 'FailureOr<mlir::ArrayAttr>' is ambiguous
    return {nullptr}; // Success in case indexing_maps was not provided.
           ^~~~~~~~~
/local/home/ekeane/llvm-project/llvm/include/llvm/Support/LogicalResult.h:80:3: note: candidate constructor
  FailureOr(LogicalResult Result) {
  ^
/local/home/ekeane/llvm-project/llvm/include/llvm/Support/LogicalResult.h:85:3: note: candidate constructor
  FailureOr(T &&Y) : std::optional<T>(std::forward<T>(Y)) {}
  ^
/local/home/ekeane/llvm-project/llvm/include/llvm/Support/LogicalResult.h:86:3: note: candidate constructor
  FailureOr(const T &Y) : std::optional<T>(Y) {}
  ^
1 error generated.

@rengolin
Copy link
Member

Should be trivial to fix, no? Can you submit a fixup patch?

@rolfmorel
Copy link
Contributor Author

I had a look and it's not clear to me what the right way to fix this should be.

Might the following work?

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b33ba1cfb87d..db6d1d2c923b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3533,7 +3533,8 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
 
 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
   if (parser.parseOptionalKeyword("indexing_maps"))
-    return {nullptr}; // Success in case indexing_maps was not provided.
+    return {
+        (ArrayAttr) nullptr}; // Success in case indexing_maps was not provided.
 
   ArrayAttr arrayAttr;
   if (parser.parseEqual() || parser.parseAttribute(arrayAttr))

This compiles with clang-16 for me.

Note I will be AFK from my work machine until Saturday late. If you could try and submit a fix-up patch that would be much appreciated.

@erichkeane
Copy link
Collaborator

Woops, I ended up getting grabbed away immediately after posting this! I'll submit a patch for RAC very much like that one above as soon as it compiles.

erichkeane added a commit that referenced this pull request Jan 31, 2025
As reported in the PR #123618, 0d4efa2 included a
construction of a `FailureOr` object with a `nullptr`, which didn't work
in at least clang-10.  This patch changes it into a constructor call
instead of a brace-init call so that it is unambiguous.
@erichkeane
Copy link
Collaborator

Fixed, cf8c730

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants