[MLIR][Linalg] Introduce linalg.contract #123618
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

Changes

A new op that allows for representing arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and arbitrary broadcasts specified through its indexing_maps attribute. Supports the expected lowerings to linalg.generic and to vector.contract.

Corresponding RFC is here: https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589

Patch is 46.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123618.diff

10 Files Affected:
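To make the description concrete, here is a minimal standalone sketch of what "contraction with transposes and broadcasts specified through indexing maps" means operationally. It is plain C++ and does not use MLIR; `IndexingMap`, `Tensor`, `offsetOf` and `contract` are hypothetical names chosen for illustration, and an indexing map is assumed to be representable as a list of iteration-space dim positions (a projected permutation).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an affine_map that is a projected permutation:
// entry r names the iteration-space dim that indexes operand dimension r,
// e.g. {0, 2} plays the role of (d0, d1, d2) -> (d0, d2).
using IndexingMap = std::vector<std::size_t>;

// Hypothetical dense row-major tensor; not an MLIR type.
struct Tensor {
  std::vector<std::size_t> shape;
  std::vector<float> data;
};

// Row-major offset of the element that `map` selects at iteration point `iv`.
std::size_t offsetOf(const Tensor &t, const IndexingMap &map,
                     const std::vector<std::size_t> &iv) {
  assert(map.size() == t.shape.size() && "map rank must match operand rank");
  std::size_t offset = 0;
  for (std::size_t r = 0; r < map.size(); ++r)
    offset = offset * t.shape[r] + iv[map[r]];
  return offset;
}

// Accumulates A * B on top of C over the iteration space given by `bounds`.
// Dims that do not appear in mapC are thereby summed over.
void contract(const Tensor &a, const IndexingMap &mapA, const Tensor &b,
              const IndexingMap &mapB, Tensor &c, const IndexingMap &mapC,
              const std::vector<std::size_t> &bounds) {
  for (std::size_t extent : bounds)
    if (extent == 0)
      return; // empty iteration space
  std::vector<std::size_t> iv(bounds.size(), 0);
  while (true) {
    c.data[offsetOf(c, mapC, iv)] +=
        a.data[offsetOf(a, mapA, iv)] * b.data[offsetOf(b, mapB, iv)];
    // Advance `iv` like an odometer, last dim fastest.
    std::size_t d = bounds.size();
    while (d > 0 && ++iv[d - 1] == bounds[d - 1]) {
      iv[d - 1] = 0;
      --d;
    }
    if (d == 0)
      return; // every dim wrapped around: all points visited
  }
}
```

With iteration space `(m, n, k)`, calling `contract` with maps `{0, 2}`, `{1, 2}`, `{0, 1}` computes a matmul whose second input is stored transposed. With iteration space `(batch, m, n, k)`, leaving `batch` out of one input's map reuses that input for every batch, which is the broadcast case.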
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+ let summary = [{
+ Perform a contraction on two inputs, accumulating on top of a third.
+ }];
+ let description = [{
+ The semantics of contracting inputs `A` and `B` on top of `C` to produce
+ output `D` is given by
+
+ `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+ where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+ dimension identifiers (meant to range over valid indices), corresponding to
+ the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+ and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+ the dimensions in the set `dims`.
+
+ The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+ domain of each of the `affine_map`s. Like for einsums, the iteration type of
+ each dim is inferred and is either:
+
+ - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+ Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+ - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+ deriving from matmul terminology - is either an "M-like" dim (if in `A`
+ and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+ `B`, and `C`).
+
+ For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+ `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+ `n` and `b` are of parallel iteration-type) and gets represented as:
+
+ ```
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ```
+
+ Note that by permuting the dims in the co-domains of the `affine_map`s, we
+ can apply arbitrary transposes to the inputs and output. Similarly,
+ arbitrary broadcasts can be achieved through leaving out dims on either
+ input operand.
+
+ Numeric casting is performed on the operands to the inner multiplication,
+ promoting them to the same data type as the accumulator/output.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ AffineMapArrayAttr:$indexing_maps
+ );
+ let results = (outs Variadic<AnyShaped>:$result_tensors);
+ let regions = (region SizedRegion<1>:$combiner);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+ outputs, attributes, regionBuilder);
+ }]>,
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, regionBuilder);
+ }]>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare/implement functions necessary for LinalgStructuredInterface.
+ /// Infer iterator types for each dim in the domain of IndexingMaps.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// IndexingMaps always depends on attr associated to current Op instance.
+ bool hasDynamicIndexingMaps() { return true; };
+ bool hasUserDefinedMaps() { return true; };
+
+ static unsigned getNumRegionArgs();
+
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ // Implement function necessary for DestinationStyleOpInterface.
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..355ed2d269291c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3528,44 +3528,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<Attribute, 3> indexingMapsAttr;
- Attribute mapAttr;
- if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
- if (parser.parseEqual())
- return failure();
+FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+ if (parser.parseOptionalKeyword("indexing_maps"))
+ return {nullptr}; // Success in case indexing_maps was not provided.
- if (parser.parseLSquare())
+ SmallVector<Attribute> indexingMaps;
+
+ auto parseIndexingMap = [&]() -> ParseResult {
+ AffineMapAttr affineMapAttr;
+ if (parser.parseAttribute(affineMapAttr))
return failure();
+ indexingMaps.push_back(affineMapAttr);
+ return success();
+ };
- do {
- if (parser.parseAttribute(mapAttr))
- return failure();
- if (!isa<AffineMapAttr>(mapAttr)) {
- return parser.emitError(parser.getCurrentLocation(),
- "expected affine map attribute");
- }
- indexingMapsAttr.push_back(mapAttr);
+ if (parser.parseEqual() ||
+ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+ parseIndexingMap))
+ return failure();
- if (parser.parseOptionalComma())
- break;
- } while (true);
+ return parser.getBuilder().getArrayAttr(indexingMaps);
+}
- if (parser.parseRSquare())
- return failure();
- }
- // Initialize indexingMaps, if not supplied explicitly.
- if (indexingMapsAttr.empty()) {
- indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+ if (failed(indexingMapsAttr))
+ return failure();
+
+ if (*indexingMapsAttr == nullptr) {
+ auto indexingMapAttrs = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
}
- result.addAttribute("indexing_maps",
- parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ result.addAttribute("indexing_maps", *indexingMapsAttr);
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
+
void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3599,6 +3600,7 @@ LogicalResult MatmulOp::verify() {
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
+
void MatmulOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+ AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+ /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+ /// domains are all the same, and each implements a projected permutation.
+ /// Each dim in the domain must occur for at least one operand and is
+ /// classified as either batch, N-like, M-like, or K-like. Only the latter
+ /// corresponds to a reduction _and_ it is the only dim-kind which does not
+ /// occur for the output operand. We use this fact for fast inference:
+ // NB: In case we allow dims to occur solely for one input, the above still
+ // holds: per the einsum semantics, these are reduction dims as well.
+ auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+ for (auto result : outAffineMap.getResults()) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(result);
+ assert(dimExpr && "affine_map is a projected permutation");
+ dimsInOutput[dimExpr.getPosition()] = true;
+ }
+
+ SmallVector<utils::IteratorType> iteratorTypes;
+ for (auto dimOccursInOutput : dimsInOutput)
+ iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+ : utils::IteratorType::reduction);
+
+ return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "ContractOp regionBuilder expects 3 args");
+ RegionBuilderHelper helper(b, block);
+
+ TypeFn castSignedness = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castSignedness = attr.getValue();
+ }
+
+ // TODO: Support fields with operators besides mult & add.
+ Type outType = block.getArgument(2).getType();
+ Value lhsAtOutType =
+ helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+ Value rhsAtOutType =
+ helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+ Value productAtOutType =
+ helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+ Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+ productAtOutType);
+ helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+ FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+ if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected 'indexing_maps' attribute");
+ result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+ return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ printNamedStructuredOp(
+ p, getOperation(), getInputs(), getOutputs(),
+ /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+ int iterationSpaceDims = -1;
+ // Maps iter space dim (as index) to num of occurrences in inputs and output.
+ SmallVector<size_t> inOccurrences;
+ SmallVector<size_t> outOccurrences;
+
+ auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+ bool isInput) -> LogicalResult {
+ if (iterationSpaceDims == -1) {
+ iterationSpaceDims = affineMap.getNumDims();
+ inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+ return emitError("iteration spaces of provided affine_maps differ");
+ }
+
+ if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+ if (affineMap.getNumResults() != shapedType.getRank())
+ return emitError("ranks of shaped operand and co-domain of "
+ "corresponding affine_map differ");
+ } else if (affineMap.getNumResults() != 0) {
+ return emitError("affine_map specifies shaped access while operand has "
+ "non-shaped type");
+ }
+
+ if (!affineMap.isProjectedPermutation())
+ return emitError("provided affine_map is not a projected permutation");
+
+ for (AffineExpr affineExpr : affineMap.getResults()) {
+ auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+ if (!affineDimExpr)
+ llvm_unreachable("affine_map is a projected permutation");
+
+ if (isInput)
+ inOccurrences[affineDimExpr.getPosition()] += 1;
+ else
+ outOccurrences[affineDimExpr.getPosition()] += 1;
+ }
+
+ return success();
+ };
+
+ for (auto &&[affineMap, operandType, isInput] :
+ llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+ SmallVector<bool>{true, true, false}))
+ if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+ return failure(); // NOTE: checking lambda will emit error.
+
+ bool hasContractingDim = false;
+ for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+ hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+ if (inOccCount == 0)
+ return emitError("iteration space dim not used by either input");
+
+ // NB: A dim which occurs for only one input operand and not for the output.
+ // In terms of einsum semantics, such dims have a sensible meaning -
+ // namely an additional reduction per such dim - though this can also
+ // always be expressed through an additional op. Additionally, at time
+ // of writing, vector.contract's verifier accepts these dims but many of
+ // its lowerings do not handle these kinds of dims. Hence...
+ // TODO: Remove following once we have comprehensive support for input-only
+ // reduction dims, at both the linalg- and vector-dialect levels.
+ if (inOccCount == 1 && outOccCount != 1)
+ return emitError("iter type of dim is not one of M, N, K or batch");
+ }
+
+ if (!hasContractingDim)
+ return emitError("'indexing_maps' do not specify a contracting dimension");
+
+ return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHEC...
[truncated]
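The dim-classification rules that `ContractOp::getIteratorTypesArray` and `ContractOp::verify` apply in the patch above can be restated without any MLIR machinery. The following is a minimal standalone sketch, assuming indexing maps are given as plain lists of dim positions; `inferIteratorTypes` and `checkMaps` are hypothetical names, not MLIR APIs, and only the occurrence-count checks from the verifier are mirrored (rank and projected-permutation checks are omitted).

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

enum class IteratorType { Parallel, Reduction };

// Hypothetical stand-in for a projected-permutation affine_map: the list of
// iteration-space dims appearing in the map's results.
using IndexingMap = std::vector<std::size_t>;

// A dim is parallel iff it occurs in the output map; otherwise it is reduced.
std::vector<IteratorType> inferIteratorTypes(std::size_t numDims,
                                             const IndexingMap &outMap) {
  std::vector<IteratorType> types(numDims, IteratorType::Reduction);
  for (std::size_t dim : outMap) {
    assert(dim < numDims && "map refers to a dim outside the iteration space");
    types[dim] = IteratorType::Parallel;
  }
  return types;
}

// Returns an error message if the maps are rejected, std::nullopt otherwise.
std::optional<std::string> checkMaps(std::size_t numDims,
                                     const IndexingMap &lhs,
                                     const IndexingMap &rhs,
                                     const IndexingMap &out) {
  // Per iteration-space dim: number of occurrences in inputs and in output.
  std::vector<std::size_t> inOcc(numDims, 0), outOcc(numDims, 0);
  for (std::size_t dim : lhs)
    ++inOcc[dim];
  for (std::size_t dim : rhs)
    ++inOcc[dim];
  for (std::size_t dim : out)
    ++outOcc[dim];

  bool hasContractingDim = false;
  for (std::size_t dim = 0; dim < numDims; ++dim) {
    // K-like: in both inputs, absent from the output.
    hasContractingDim |= inOcc[dim] == 2 && outOcc[dim] == 0;
    if (inOcc[dim] == 0)
      return "iteration space dim not used by either input";
    // Input-only dims (one input, not the output) are rejected for now.
    if (inOcc[dim] == 1 && outOcc[dim] != 1)
      return "iter type of dim is not one of M, N, K or batch";
  }
  if (!hasContractingDim)
    return "'indexing_maps' do not specify a contracting dimension";
  return std::nullopt;
}
```

For the batch-matmul maps from the op description, `(batch, m, n, k)` with results `(batch, m, k)`, `(batch, k, n)` and `(batch, m, n)`, this yields three parallel dims followed by one reduction dim, and `checkMaps` accepts the maps.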
Thanks for the small scoped change. Much easier to review. I just have one question.
Cool. No blockers for me. Please wait for others to take a look.
I would also highly recommend trying to make the parser/printer work with assemblyFormat. It will make things much better going forward.
Sorry @MaheshRavishankar, my initial take on your comment regarding regions was wrong. Have updated my comment up above. Please let me know if that still addresses your concern.
Dismissed prematurely. Just want to get clarification on the region of the op, but otherwise this looks good.
Looks really good 👍
Just a few minor nits
Thanks for looking over the PR, @adam-smnk ! I have now addressed your remarks - let me know what you think.
Dismissing since the region handling will be done as a follow up
Great stuff, thanks!
Mostly drive-by, though I do have one high level question. What about the scaling factor that is sometimes used for the accumulator matrix in GEMM implementations? Sorry if this was mentioned somewhere!
Also, parsing this on my tablet, so bear with me 😅
Thanks for this round of comments, @banach-space ! I hope I have addressed all of them now.
I don't think anybody brought it up before! If I am not mistaken though it suffers from the same issue as truncation (though in reverse): because we are dealing with a contraction, the op's associated body will run multiple times. How should we denote - in terms of the perfect loop nests correspondence for "structured ops" - that only on the first run of the body we should apply the scaling to C's elements? Phrased alternatively: what
If by "scaling factor" you mean Micro-Scaling (https://arxiv.org/pdf/2310.10537), then this will be added later across all linalg ops, or possibly as a separate type dialect (https://github.com/libxsmm/tpp-mlir/wiki/Microscaling-Data-Formats-in-MLIR). I'm not sure yet how this will work. Currently, IIUC, MXFP implementations just separate the scaling from the matmul ops and tile them with a custom algorithm.
Ah, now I see what the problem is. Yeah, I don't see how that could be represented with a single GenericOp 🤔
Oh, that's way ahead of what I was thinking 😅 But also an important one to keep in mind.
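The concern about accumulator scaling can be made concrete with a small standalone sketch. It assumes "scaling factor" is read as a BLAS-style `beta` in `D = A * B + beta * C` (an assumption about the question, not something the thread confirms), and it reduces the contraction to a single dot product. The function names `scaledInBody` and `scaledUpFront` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scaling folded into the per-iteration body: because the body runs once per
// reduction step, the accumulator is rescaled at every step, not just once.
float scaledInBody(const std::vector<float> &a, const std::vector<float> &b,
                   float c, float beta) {
  assert(a.size() == b.size() && "inputs must have the same length");
  float acc = c;
  for (std::size_t k = 0; k < a.size(); ++k)
    acc = beta * acc + a[k] * b[k];
  return acc;
}

// Intended semantics: scale the accumulator once, then contract on top of it.
// This needs a separate step before the reduction loop.
float scaledUpFront(const std::vector<float> &a, const std::vector<float> &b,
                    float c, float beta) {
  assert(a.size() == b.size() && "inputs must have the same length");
  float acc = beta * c;
  for (std::size_t k = 0; k < a.size(); ++k)
    acc += a[k] * b[k];
  return acc;
}
```

With `a = b = {1, 1}`, `c = 4` and `beta = 0.5`, the in-body variant returns 2.5 while the scale-once variant returns 4. The two agree only when the reduction has a single step, which illustrates why a uniform loop body cannot express "apply the scaling only on the first run".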
Thanks for all the formatting updates, that's much appreciated! 🙏🏻
I've skimmed through some other bits and made small suggestions. All in all LG, in line with the RFC.
Btw, thank you for all the diligence and enormous effort that has gone into figuring this out! This will have quite an impact on Linalg, so no pressure 😂
Thank you for the thorough review, @banach-space ! I believe everything is addressed now. You could maybe have a look at the expanded docs and see if you are happy with what they say on broadcasting. Otherwise I think we are close to landing this PR. Thanks again!
Thanks for all the updates to the docs, all that info will be super helpful! The stuff on bcast dims is now super clear 🙏🏻
I've left a couple of minor comments/nits, but these are non-blockers. I won't have any further comments so approving as is.
LGTM, great work!
Thanks @banach-space ! On your prompting, I have gone over the docs and comments once more and cleaned them up further. My thanks in general to the reviewers - it has improved the PR considerably! As I am satisfied with the state of the PR, and I believe the others who commented are happy as well, I will merge the PR in a couple of hours (after a rebase and Thanks again to all those who participated in the RFC and on the PR! |
e8c3c1b to ca65e0c
A new op that allows for representing arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and arbitrary broadcasts specified through its indexing_maps attribute. Supports the expected lowerings to linalg.generic and to vector.contract.
This reverts commit 3d0d5b3. Actually, that line is needed...
ca65e0c to 072da4b
I have clang-10 on my machine, and this doesn't compile! See:
Should be trivial to fix, no? Can you submit a fixup patch?
I had a look and it's not clear to me what the right way to fix this should be. Might the following work?
This compiles with Note I will be AFK from my work machine until Saturday late. If you could try and submit a fix-up patch that would be much appreciated.
Woops, I ended up getting grabbed away immediately after posting this! I'll submit a patch for RAC very much like that one above as soon as it compiles.
Fixed, cf8c730