[mlir][linalg] Extend Linalg elemwise named ops semantics #122753
Conversation
Implements the Linalg elemwise named op following the proposal and discussions in the RFC: https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1. Discussion is ongoing on the RFC, especially about `comp_type`, so that part is left open/unimplemented in this diff.
@llvm/pr-subscribers-mlir Author: Javed Absar (javedabsar1). Changes: implements the Linalg elemwise named op following the RFC; discussion is ongoing on the RFC, especially about `comp_type`. Patch is 33.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122753.diff — 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..115eaebc6aff54 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,11 @@ def Linalg_Dialect : Dialect {
}];
}
+// Define the enum-type Elemwise func attribute.
+def ElemwiseFnAttr : EnumAttr<Linalg_Dialect, ElemwiseFn, "elemwise_fn"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..5135e9cd4386ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -15,6 +15,57 @@
include "mlir/IR/EnumAttr.td"
+// Define an `enum class : i32` to categorise element-wise op.
+def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
+ I32EnumAttrCase<"Unary", 0>,
+ I32EnumAttrCase<"Binary", 1>,
+ I32EnumAttrCase<"Ternary", 2>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
+// Define a unified `enum class : i32` for all element-wise options.
+// Note: The order of the individual functions (e.g. 'exp', 'log') within
+// each category (Unary, Binary, etc.) must match the ordering of the same
+// functions in UnaryFn, BinaryFn, and TernaryFn. This enables a correct
+// mapping from this unified enum class to the per-category enums.
+def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [
+ // Unary
+ I32EnumAttrCase<"exp", 0>,
+ I32EnumAttrCase<"log", 1>,
+ I32EnumAttrCase<"abs", 2>,
+ I32EnumAttrCase<"ceil", 3>,
+ I32EnumAttrCase<"floor", 4>,
+ I32EnumAttrCase<"negf", 5>,
+ I32EnumAttrCase<"reciprocal", 6>,
+ I32EnumAttrCase<"round", 7>,
+ I32EnumAttrCase<"sqrt", 8>,
+ I32EnumAttrCase<"rsqrt", 9>,
+ I32EnumAttrCase<"square", 10>,
+ I32EnumAttrCase<"tanh", 11>,
+ I32EnumAttrCase<"erf", 12>,
+
+ // Binary
+
+ I32EnumAttrCase<"add", 13>,
+ I32EnumAttrCase<"sub", 14>,
+ I32EnumAttrCase<"mul", 15>,
+ I32EnumAttrCase<"div", 16>,
+ I32EnumAttrCase<"div_unsigned", 17>,
+ I32EnumAttrCase<"max_signed", 18>,
+ I32EnumAttrCase<"min_signed", 19>,
+ I32EnumAttrCase<"max_unsigned", 20>,
+ I32EnumAttrCase<"min_unsigned", 21>,
+ I32EnumAttrCase<"powf", 22>,
+
+ // Ternary
+ I32EnumAttrCase<"select", 23>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"exp", 0>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..6d6ff7a5c7872a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
+//===----------------------------------------------------------------------===//
+
+def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
+ AttrSizedOperandSegments]> {
+ let summary = [{ Performs an element-wise operation }];
+ let description = [{
+ Linalg op which performs an element-wise computation. The attribute
+ `func_type` describes the operation kind (e.g. add, exp). The func_type
+ can be any valid unary, binary, or ternary operation.
+
+ Affine maps for operands and result may be provided by the user. When
+ a user-defined indexing_map is not provided, an identity map is inferred
+ for all operands. The default indexing maps are N identity maps, where
+ 'N' depends on the arity of the elementwise op. The number of dims is
+ inferred from the rank of the output type. With the default indexing
+ maps, the input and output shapes must all match. Affine maps for
+ operands and result must be projected permutations with no zero
+ constants.
+
+ For an element-wise op the iterator types are always inferred as all
+ 'parallel'. Iterator types are needed for constructing the underlying
+ structured op, and the number of dims of the iterator types is inferred
+ from the rank of the result type.
+
+ Example:
+ Defining a unary linalg.elemwise with default indexing-map:
+
+ ```mlir
+ %exp = linalg.elemwise
+ func_type=#linalg.elemwise_fn<exp>
+ ins(%x : tensor<4x16x8xf32>)
+ outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+ ```
+
+ Defining a binary linalg.elemwise with user-defined indexing-map:
+
+ ```mlir
+ %add = linalg.elemwise
+ func_type=#linalg.elemwise_fn<add>
+ indexing_maps = [#transpose, #broadcast, #identity]
+ ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+ outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ ElemwiseFnAttr:$func_type,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElemwiseOp::getRegionBuilder());
+ }]>
+ ];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+ /// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`,
+ /// corresponding to the given fn, e.g. `ElemwiseFn::exp`
+ static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn);
+
+ /// Elementwise ops always have dynamic indexing maps, i.e. either
+ /// user-specified or default. The default is identity maps.
+ static bool hasDynamicIndexingMaps() { return true; }
+
+ /// Implements the block region builder for the ElemwiseOp. This is called
+ /// by the 'fillStructuredOpRegion'.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns elementwise op kind e.g. `add` inferred from func_type attr.
+ ElemwiseFn getElemwiseFnVal() {
+ return getFuncType();
+ }
+
+ /// Infer the dimensionality of the iteration space from the result
+ /// type. Useful when other means are not available, e.g. in the
+ /// absence of a user-provided indexing map.
+ unsigned getResultRank();
+
+ /// Elementwise op does not have to explicitly specify iterator type
+ /// as it is always 'parallel'. The number of 'parallel' loops is
+ /// inferred from other means (e.g. result tensor type).
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// The default indexing maps are N identity maps. 'N' depends on the
+ /// arity of the elementwise op. The default covers the case where all
+ /// input and output tensors have the same rank and no transpose or
+ /// broadcast is needed.
+ static SmallVector<AffineMap>
+ getDefaultIndexingMaps(unsigned N, unsigned numDims,
+ MLIRContext *context);
+
+ /// Returns true if the user defined indexing maps are not equal to
+ /// the default (identity) map.
+ bool hasUserDefinedMaps();
+
+ /// Destination-passing-style (DPS) interface method.
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ std::string getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..c84220f5b4f2ce 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildElemwiseOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder) {
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3566,6 +3575,7 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
+
void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3611,5 +3621,282 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ElemwiseOp - with support for affine map, func_type and comp_type
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct NAryCategoryAndFn {
+ // The enum category class {Unary, Binary, Ternary, ..}
+ ElemwiseNAryCategory category;
+
+ union NAryFn {
+ UnaryFn unaryFn;
+ BinaryFn binaryFn;
+ TernaryFn ternaryFn;
+ } fn;
+
+ ::llvm::StringRef stringifyCategory() {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return "unary";
+ case ElemwiseNAryCategory::Binary:
+ return "binary";
+ case ElemwiseNAryCategory::Ternary:
+ return "ternary";
+ }
+ llvm_unreachable("unknown-category");
+ }
+
+ ::llvm::StringRef stringifyFn() {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return stringifyUnaryFn(fn.unaryFn);
+ case ElemwiseNAryCategory::Binary:
+ return stringifyBinaryFn(fn.binaryFn);
+ case ElemwiseNAryCategory::Ternary:
+ return stringifyTernaryFn(fn.ternaryFn);
+ }
+ llvm_unreachable("unknown-fn");
+ }
+};
+
+unsigned getArityFromCategory(ElemwiseNAryCategory category) {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return 1;
+ case ElemwiseNAryCategory::Binary:
+ return 2;
+ case ElemwiseNAryCategory::Ternary:
+ return 3;
+ }
+ llvm_unreachable("unhandled category");
+}
+} // namespace
+
+static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) {
+ constexpr int lastUnary = static_cast<int>(ElemwiseFn::erf);
+ constexpr int lastBinary = static_cast<int>(ElemwiseFn::powf);
+ constexpr int lastTernary = static_cast<int>(ElemwiseFn::select);
+
+ int val = static_cast<int>(fn);
+ NAryCategoryAndFn result;
+ if (val <= lastUnary) {
+ result.category = ElemwiseNAryCategory::Unary;
+ result.fn.unaryFn = static_cast<UnaryFn>(val);
+ return result;
+ }
+ if (val <= lastBinary) {
+ result.category = ElemwiseNAryCategory::Binary;
+ result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+ return result;
+ }
+ if (val > lastTernary) {
+ llvm_unreachable("unhandled ElemwiseFn");
+ }
+ result.category = ElemwiseNAryCategory::Ternary;
+ result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+ return result;
+}
+
+unsigned ElemwiseOp::getResultRank() {
+ auto output = getDpsInitOperand(0)->get();
+ auto shapedType = llvm::cast<ShapedType>(output.getType());
+ return shapedType.getRank();
+}
+
+SmallVector<utils::IteratorType> ElemwiseOp::getIteratorTypesArray() {
+ auto rank = getResultRank();
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+ MLIRContext *context) {
+ auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+ return SmallVector<AffineMap>(numMaps, map);
+}
+
+bool ElemwiseOp::hasUserDefinedMaps() {
+ auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+ auto arity = getArityFromCategory(category);
+
+ auto numDims = getResultRank();
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(arity + 1, numDims, this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+ // Expect e.g. `func_type = #linalg.elemwise_fn<add>`
+ Attribute attr;
+ mlir::linalg::ElemwiseFn elemwiseFnVal;
+ if (parser.parseKeyword("func_type"))
+ return failure();
+ if (parser.parseEqual())
+ return failure();
+
+ if (succeeded(parser.parseAttribute(attr))) {
+ auto elemwiseFnAttr = dyn_cast<ElemwiseFnAttr>(attr);
+ if (!elemwiseFnAttr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ElemwiseFn attribute");
+ elemwiseFnVal = elemwiseFnAttr.getValue();
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected 'func_type' attribute");
+ }
+ result.addAttribute("func_type",
+ ElemwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+
+ // Parse optional `indexing_maps`
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+
+ if (!isa<AffineMapAttr>(mapAttr))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+
+ // At this stage of parsing, the only way to infer the number of region
+ // args is through the op kind, as input/output tensors are not parsed yet.
+ auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+ auto arity = getArityFromCategory(arityAndCategory.category);
+ int numRegionArgs = arity + 1 /*output*/;
+
+ if (parseNamedStructuredOp(parser, result, numRegionArgs,
+ ElemwiseOp::getRegionBuilder())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "unable to parse elemwise op");
+ }
+
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ // Infer the number of indexing maps needed from the result type,
+ // which has been parsed by now.
+ auto resultType = result.operands[result.operands.size() - 1].getType();
+ auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+ if (!shapedType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "return type needs to be shaped type");
+ auto numDims = shapedType.getRank();
+ indexingMapsAttr = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+ parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return success();
+}
+
+void ElemwiseOp::print(OpAsmPrinter &p) {
+ p << " func_type=";
+ p.printAttribute(getFuncTypeAttr());
+
+ SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "func_type",
+ "indexing_maps"};
+
+ auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+ auto arity = getArityFromCategory(category);
+
+ auto numDims = getResultRank();
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+LogicalResult ElemwiseOp::verify() {
+ // All necessary checks are done either by
+ // - EnumAttr (e.g. unknown func_type)
+ // - verifyStructuredOpInterface (incorrect map, sizes).
+ return success();
+}
+
+/// Implements the block region builder for the ElemwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ ElemwiseFn elemwiseFn;
+ for (auto attr : attrs) {
+ if (attr.getName() == b.getStringAttr("func_type")) {
+ auto funcTypeAttr = dyn_cast<ElemwiseFnAttr>(attr.getValue());
+ assert(funcTypeAttr && "func_type attribute incorrectly set");
+ elemwiseFn = funcTypeAttr.getValue();
+ break;
+ }
+ }
+
+ NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+ ElemwiseNAryCategory category = categoryAndFn.category;
+
+ unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+ assert(block.getNumArguments() == numBlockArgs &&
+ "Elemwise regionBuilder number of block args mismatch");
+
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+ Value result;
+
+ if (category == ElemwiseNAryCategory::Unary) {
+ result =
+ helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+
+ } else if (category == ElemwiseNAryCategory::Binary) {
+ result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+ block.getArgument(0), block.getArgument(1));
+ } else if (category == ElemwiseNAryCategory::Ternary) {
+ result = helper.buildTernaryFn(categoryAndFn.fn.ternaryFn,
+ block.getArgument(0), block.getArgument(1), block.getArgument(2));
+ } else {
+ assert(false && "unhandled category in elemwise regionBuilder");
+ }
+
+ yields.push_back(result);
+ helper.yieldOutputs(yields);
+}
+
+LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void ElemwiseOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElemwiseOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir
new file mode 100644
index 00...
[truncated]
@llvm/pr-subscribers-mlir-linalg Author: Javed Absar (javedabsar1). Changes: implements the Linalg elemwise named op following the RFC; discussion is ongoing on the RFC, especially about `comp_type`. Patch is 33.79 KiB, truncated to 20.00 KiB; full version: https://github.com/llvm/llvm-project/pull/122753.diff — 7 Files Affected.
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+
+  // At this stage of parsing, the only way to infer the number of region
+  // arguments is through the op kind, as the input/output tensors are not
+  // parsed yet.
+ auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+ auto arity = getArityFromCategory(arityAndCategory.category);
+ int numRegionArgs = arity + 1 /*output*/;
+
+ if (parseNamedStructuredOp(parser, result, numRegionArgs,
+ ElemwiseOp::getRegionBuilder())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "unable to parse elemwise op");
+ }
+
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+    // Infer the number of indexing maps needed from the result type,
+    // which has been parsed by now.
+ auto resultType = result.operands[result.operands.size() - 1].getType();
+ auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+ if (!shapedType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "return type needs to be shaped type");
+ auto numDims = shapedType.getRank();
+ indexingMapsAttr = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+ parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return success();
+}
+
+void ElemwiseOp::print(OpAsmPrinter &p) {
+ p << " func_type=";
+ p.printAttribute(getFuncTypeAttr());
+
+ SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "func_type",
+ "indexing_maps"};
+
+ auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+ auto arity = getArityFromCategory(category);
+
+ auto numDims = getResultRank();
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+LogicalResult ElemwiseOp::verify() {
+ // All necessary checks are done either by
+ // - EnumAttr (e.g. unknown func_type)
+ // - verifyStructuredOpInterface (incorrect map, sizes).
+ return success();
+}
+
+/// Implements the block region builder for the ElemwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ ElemwiseFn elemwiseFn;
+ for (auto attr : attrs) {
+ if (attr.getName() == b.getStringAttr("func_type")) {
+ auto funcTypeAttr = dyn_cast<ElemwiseFnAttr>(attr.getValue());
+ assert(funcTypeAttr && "func_type attribute incorrectly set");
+ elemwiseFn = funcTypeAttr.getValue();
+ break;
+ }
+ }
+
+ NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+ ElemwiseNAryCategory category = categoryAndFn.category;
+
+ unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+ assert(block.getNumArguments() == numBlockArgs &&
+ "Elemwise regionBuilder number of block args mismatch");
+
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+ Value result;
+
+ if (category == ElemwiseNAryCategory::Unary) {
+ result =
+ helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+
+ } else if (category == ElemwiseNAryCategory::Binary) {
+ result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+ block.getArgument(0), block.getArgument(1));
+ } else if (category == ElemwiseNAryCategory::Ternary) {
+ result = helper.buildTernaryFn(categoryAndFn.fn.ternaryFn,
+ block.getArgument(0), block.getArgument(1), block.getArgument(2));
+  } else {
+    llvm_unreachable("unhandled category in elemwise regionBuilder");
+  }
+
+ yields.push_back(result);
+ helper.yieldOutputs(yields);
+}
+
+LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void ElemwiseOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElemwiseOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir
new file mode 100644
index 00...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
After the first pass, already looks really good 👍
auto arity = getArityFromCategory(category);

auto numDims = getResultRank();
SmallVector<AffineMap, 3> defaultMaps =
nit: it can be just `SmallVector<AffineMap>` as size changes with arity
I32EnumAttrCase<"erf", 12>,

// Binary
nit: unnecessary extra new line
/// or `default`. Default is identity-maps.
static bool hasDynamicIndexingMaps() { return true; }

/// Implements the block region builder for the eemwiseOp. This is called
/// Implements the block region builder for the eemwiseOp. This is called
/// Implements the block region builder for the elemwiseOp. This is called
I didn't look at all the details, but left a few comments. Structure overall seems fine to me.
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
//===----------------------------------------------------------------------===//

def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
Nit: I'd rather just use `ElementWiseOp` and `element_wise`. But totally my preference.
+1 for full spelling, feels more intuitive.
nit to nit: I'd go with a single-word version, `ElementwiseOp` and `elementwise`; it's also more consistent with spelling across the repo.
Nit: I'd rather just use `ElementWiseOp` and `element_wise`. But totally my preference.

I too agree. I just wasn't sure what you folks would prefer (longer IR and precise, or shorter and guess-it).
@@ -15,6 +15,57 @@

include "mlir/IR/EnumAttr.td"

// Define an `enum class : i32` to categorise element-wise op.
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
This might be frowned upon, but I don't think we need two different sets of enums. We could just use a range of enums in `ElemwiseFn` below to denote unary, binary and ternary functions...
return defaultMaps != explicitMaps;
}

ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
Same comment as what I left on the matmul op PR. I think this could be done with assembly format with just custom parser/printer functions.
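A declarative form along those lines might look roughly like this (hypothetical sketch — the directive names, attribute names, and optional-group anchors would need checking against the actual op definition; shown in TableGen since that is where `assemblyFormat` lives):

```tablegen
// Hypothetical sketch: keep most of the syntax declarative and push only
// the non-standard piece (optional indexing_maps with a computed default)
// into a custom<> directive backed by hand-written helpers.
let assemblyFormat = [{
  `func_type` `=` $func_type
  custom<IndexingMaps>($indexing_maps)
  attr-dict
  `ins` `(` $inputs `:` type($inputs) `)`
  `outs` `(` $outputs `:` type($outputs) `)`
  (`->` type($result_tensors)^)?
}];
```

Here `custom<IndexingMaps>` would dispatch to `parseIndexingMaps`/`printIndexingMaps` helper functions that handle the optional attribute and elide it when it matches the default identity maps.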
Thanks @javedabsar1, excited to see this materialising 🙇🏻 🙏🏻
Some minor comments in tests while I go over the other details.
Why do we need a new directory?
Because I reckon the number of tests will increase over time as elementwise develops further.
// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @fewer_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
[nit]
func.func @fewer_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @incorrect_transpose_map(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
[nit]
func.func @incorrect_transpose_map(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// CHECK: @multiple_dims(%[[A:.+]]: tensor<1x2x3x4x5xi32>, %[[B:.+]]: tensor<1x2x3x4x5xi32>, %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>) outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
[nit] Please use consistent naming.
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
func.func @multiple_dims(%A : tensor<1x2x3x4x5xi32>, %B : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
%r = linalg.elemwise func_type=#linalg.elemwise_fn<mul>
    indexing_maps = [#map, #map, #map]
    ins(%arg0, %arg1: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
    outs(%arg2: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
This indentation is different to the rest of the file. Please unify. Personally I prefer this "reduced" indentation.
// CHECK: @binary_mul(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @binary_mul(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
[nit] To make a clearer distinction from multiple_dims:
func.func @binary_mul(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
func.func @binary_mu_2Dl(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
er, I think you mean `binary_mul_2D` 😆
// CHECK: @multiple_dims(%[[A:.+]]: tensor<1x2x3x4x5xi32>, %[[B:.+]]: tensor<1x2x3x4x5xi32>, %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>) outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
[nit] To make a clearer distinction from @binary_mul:
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
func.func @binary_mul_5D(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
[nit] Could you make sure that the test function naming is consistent in this file? From what I can see, it goes like this: `@{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}`. I'm not sure how to handle `on_tensor` and `on_memref` 😅
High-level question for everyone: Please feel free to shut this thread down if you think this isn't worth discussing (to avoid bike-shedding). To me,

func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
  %r = linalg.elemwise
      func_type=#linalg.elemwise_fn<exp>
      ins(%A : tensor<8x16x32xf32>)
      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
  return %r : tensor<8x16x32xf32>
}

Would something like … read better? Ultimately, it seems to describe "a mathematical operation inside the Linalg Op body." Do these alternatives make sense, or do you feel … is fine?
To me, they all (together with …).

Nothing better comes to my mind, but I wouldn't mind if we could align the attr name with something already existing in the wild. Maybe even something from PyTorch or IREE? Such that most people could recognize it more intuitively.
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
[nit] Let's stick with the existing convention.
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
// Op definition for ElemwiseOp
let description = [{
Linalg op form which performs element-wise computation. The attribute
`func_type` describes the operation type (e.g. add, exp). The func_type
can be any valid unary, binary, or ternary operation.
How are we defining "any valid"? I would either list all or skip this altogether.
for all operands. The default indexing maps are N identity-maps. ‘N’
depends on the arity of the elementwise op. The number of dims is
Isn't N simply the rank of the output? And doesn't "arity" mean the number of function args? Is that relevant here?
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
inferred from rank of the output type. In the case of default indexing
maps, the input and output shapes must match. Affine-maps for operands
depends on the arity of the elementwise op. The number of dims is
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
and result must be only projected permutations with no zero constants.
and result must be only projected permutations with no zero constants.
and result must be projected permutations with no zero constants.
/// the default (identity) map.
bool hasUserDefinedMaps();

/// destination passing style interface method.
/// destination passing style interface method.
/// Destination passing style interface method.
/// Infer dimensionality of the `iteration space` from the result type.
/// Useful when others means are not possible e.g. in case of absence of
/// user-provided indexing map.
Doesn't this simply return the rank of the result? There isn't much "inference" involved 😅 As per your implementation: `return shapedType.getRank();`.
/// Elementwise op does not have to explicitly specify iterator type
/// as it is always 'parallel'. The number of 'parallel' loops is
/// inferred from other means (e.g. result tensor type).
This part is already included in the general Op description above:

/// Elementwise op does not have to explicitly specify iterator type
/// as it is always 'parallel'

Please only document what `getIteratorTypesArray` does.
/// The default indexing maps are N identity-maps. 'N' depends on the
/// arity of the elementwise op. The default case is when all input
/// output tensors are same rank and no transpose/broadcast is needed.
It would be good to refine this; to me "N identity-maps" could mean "identity-maps with N dims". Perhaps: "The default indexing maps are identities. There will be `N` such maps, where `N` is the arity of the Op."
`func_type` describes the operation type (e.g. add, exp). The func_type
can be any valid unary, binary, or ternary operation.

Affine-maps for operands and result may be provided by the user. When
Affine-maps for operands and result may be provided by the user.
IIUC, maps are required when transposing or broadcasting?
Whatever we choose to call it, we need to keep in mind that we will have two attributes that should be consistently named. How about …?
Can we have tests that show binary ops with a single argument / unary ops with more than one argument being rejected?
IIUC, we don't need "compute type" just now? Shall we cross that bridge when we get there and in the meantime prioritise consistency and use …?
@@ -15,6 +15,57 @@

include "mlir/IR/EnumAttr.td"

// Define an `enum class : i32` to categorise element-wise op.
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
I32EnumAttrCase<"Unary", 0>,
Do we need separate enums here? Can we just use the enum below to classify what is a unary/binary or ternary function? (I think this is a repeat comment from above)
);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
Do we need a region on the op?
Just to cross-pollinate the Linalg discussions.
It looks like regions are required by Linalg's interface. Thus, all ops currently have and should have a region.
Left a couple of other comments. I think it would be useful to not have regions on the operations since they seem unnecessary. Apart from that this looks good to me.
created a new PR -- #124661