
[mlir][linalg] Extend Linalg elemwise named ops semantics #122753

Closed
wants to merge 4 commits into from

Conversation

javedabsar1
Contributor

Implements Linalg elemwise named-op following the
proposal and discussions in RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1

Discussions are ongoing on the RFC, especially about
`comp_type`, so that part is left open/unimplemented in this diff.
@llvmbot
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

Changes

Implements Linalg elemwise named-op following the
proposal and discussions in RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1

Discussions are ongoing on the RFC, especially about
comp_type, so that part is left open/unimplemented in this diff.


Patch is 33.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122753.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+5)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+51)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+130)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+287)
  • (added) mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir (+170)
  • (added) mlir/test/Dialect/Linalg/elemwise/invalid.mlir (+39)
  • (added) mlir/test/Dialect/Linalg/elemwise/round-trip.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..115eaebc6aff54 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,11 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
+// Define the enum-type Elemwise func attribute.
+def ElemwiseFnAttr : EnumAttr<Linalg_Dialect, ElemwiseFn, "elemwise_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
   let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..5135e9cd4386ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -15,6 +15,57 @@
 
 include "mlir/IR/EnumAttr.td"
 
+// Define an `enum class : i32` to categorise element-wise op.
+def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
+  I32EnumAttrCase<"Unary", 0>,
+  I32EnumAttrCase<"Binary", 1>,
+  I32EnumAttrCase<"Ternary", 2>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define a unified `enum class : i32` for all element-wise options.
+// Note: The order of the individual fns (e.g. 'exp', 'log') within each
+// category (Unary, Binary, etc.) must match the ordering of the same fns
+// in UnaryFn, BinaryFn, and TernaryFn. This enables a correct mapping
+// from this unified enum class to the per-category enums.
+def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [
+  // Unary
+  I32EnumAttrCase<"exp", 0>,
+  I32EnumAttrCase<"log", 1>,
+  I32EnumAttrCase<"abs", 2>,
+  I32EnumAttrCase<"ceil", 3>,
+  I32EnumAttrCase<"floor", 4>,
+  I32EnumAttrCase<"negf", 5>,
+  I32EnumAttrCase<"reciprocal", 6>,
+  I32EnumAttrCase<"round", 7>,
+  I32EnumAttrCase<"sqrt", 8>,
+  I32EnumAttrCase<"rsqrt", 9>,
+  I32EnumAttrCase<"square", 10>,
+  I32EnumAttrCase<"tanh", 11>,
+  I32EnumAttrCase<"erf", 12>,
+
+  // Binary
+
+  I32EnumAttrCase<"add", 13>,
+  I32EnumAttrCase<"sub", 14>,
+  I32EnumAttrCase<"mul", 15>,
+  I32EnumAttrCase<"div", 16>,
+  I32EnumAttrCase<"div_unsigned", 17>,
+  I32EnumAttrCase<"max_signed", 18>,
+  I32EnumAttrCase<"min_signed", 19>,
+  I32EnumAttrCase<"max_unsigned", 20>,
+  I32EnumAttrCase<"min_unsigned", 21>,
+  I32EnumAttrCase<"powf", 22>,
+
+  // Ternary
+  I32EnumAttrCase<"select", 23>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"exp", 0>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..6d6ff7a5c7872a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
+//===----------------------------------------------------------------------===//
+
+def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
+                   AttrSizedOperandSegments]> {
+  let summary = [{ Performs element-wise operation }];
+  let description = [{
+    Linalg op that performs element-wise computation. The attribute
+    `func_type` describes the operation kind (e.g. add, exp) and may be
+    any valid unary, binary, or ternary operation.
+
+    Affine maps for operands and result may be provided by the user. When
+    user-defined indexing maps are not provided, identity maps are inferred
+    for all operands. The default indexing maps are N identity maps, where
+    'N' depends on the arity of the elementwise op; the number of dims is
+    inferred from the rank of the output type. With the default indexing
+    maps, the input and output shapes must all match. User-provided affine
+    maps for operands and result must be projected permutations with no
+    zero constants.
+
+    The iterator types are always inferred as all 'parallel', as needed for
+    constructing the underlying structured op. The number of dims of the
+    iterator types is inferred from the rank of the result type.
+
+    Example:
+    Defining a unary linalg.elemwise with default indexing-map:
+
+      ```mlir
+      %exp = linalg.elemwise
+             func_type=#linalg.elemwise_fn<exp>
+             ins(%x : tensor<4x16x8xf32>)
+             outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+      ```
+
+    Defining a binary linalg.elemwise with user-defined indexing-map:
+
+    ```mlir
+    %add = linalg.elemwise
+            func_type=#linalg.elemwise_fn<add>
+            indexing_maps = [#transpose, #broadcast, #identity]
+            ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+            outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      ElemwiseFnAttr:$func_type,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, ElemwiseOp::getRegionBuilder());
+      }]>
+    ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      /// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`,
+      /// corresponding to the given fn, e.g. `ElemwiseFn::exp`
+      static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn);
+
+      /// Elementwise is always `dynamic indexing maps` i.e. `user specified`
+      /// or `default`. Default is identity-maps.
+      static bool hasDynamicIndexingMaps() { return true; }
+
+      /// Implements the block region builder for the ElemwiseOp. This is called
+      /// by the 'fillStructuredOpRegion'.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns elementwise op kind e.g. `add` inferred from func_type attr.
+      ElemwiseFn getElemwiseFnVal() {
+        return getFuncType();
+      }
+
+      /// Infer dimensionality of the `iteration space` from the result type.
+      /// Useful when other means are not possible, e.g. in the absence of a
+      /// user-provided indexing map.
+      unsigned getResultRank();
+
+      /// Elementwise op does not have to explicitly specify iterator type
+      /// as it is always 'parallel'. The number of 'parallel' loops is
+      /// inferred from other means (e.g. result tensor type).
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// The default indexing maps are N identity-maps. 'N' depends on the
+      /// arity of the elementwise op. The default case is when all input and
+      /// output tensors have the same rank and no transpose/broadcast is needed.
+      static SmallVector<AffineMap>
+      getDefaultIndexingMaps(unsigned N, unsigned numDims,
+                             MLIRContext *context);
+
+      /// Returns true if the user-defined indexing maps are not equal to
+      /// the default (identity) maps.
+      bool hasUserDefinedMaps();
+
+      /// destination passing style interface method.
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      std::string getLibraryCallName() {
+        return generateLibraryCallName(getOperation());
+      }
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..c84220f5b4f2ce 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildElemwiseOp(OpBuilder &b, OperationState &state,
+                            std::optional<TypeRange> resultTensorTypes,
+                            ValueRange inputs, ValueRange outputs,
+                            ArrayRef<NamedAttribute> attributes,
+                            RegionBuilderFn regionBuilder) {
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3566,6 +3575,7 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3611,5 +3621,282 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ElemwiseOp - with support for affine map, func_type and comp_type
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct NAryCategoryAndFn {
+  // The enum category class {Unary, Binary, Ternary, ..}
+  ElemwiseNAryCategory category;
+
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  ::llvm::StringRef stringifyCategory() {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return "unary";
+    case ElemwiseNAryCategory::Binary:
+      return "binary";
+    case ElemwiseNAryCategory::Ternary:
+      return "ternary";
+    }
+    llvm_unreachable("unknown-category");
+  }
+
+  ::llvm::StringRef stringifyFn() {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElemwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElemwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+
+unsigned getArityFromCategory(ElemwiseNAryCategory category) {
+  switch (category) {
+  case ElemwiseNAryCategory::Unary:
+    return 1;
+  case ElemwiseNAryCategory::Binary:
+    return 2;
+  case ElemwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+} // namespace
+
+static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) {
+  constexpr int lastUnary = static_cast<int>(ElemwiseFn::erf);
+  constexpr int lastBinary = static_cast<int>(ElemwiseFn::powf);
+  constexpr int lastTernary = static_cast<int>(ElemwiseFn::select);
+
+  int val = static_cast<int>(fn);
+  NAryCategoryAndFn result;
+  if (val <= lastUnary) {
+    result.category = ElemwiseNAryCategory::Unary;
+    result.fn.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val <= lastBinary) {
+    result.category = ElemwiseNAryCategory::Binary;
+    result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+    return result;
+  }
+  if (val > lastTernary) {
+    llvm_unreachable("unhandled ElemwiseFn");
+  }
+  result.category = ElemwiseNAryCategory::Ternary;
+  result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+  return result;
+}
+
+unsigned ElemwiseOp::getResultRank() {
+  auto output = getDpsInitOperand(0)->get();
+  auto shapedType = llvm::cast<ShapedType>(output.getType());
+  return shapedType.getRank();
+}
+
+SmallVector<utils::IteratorType> ElemwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                   MLIRContext *context) {
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+}
+
+bool ElemwiseOp::hasUserDefinedMaps() {
+  auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+  auto arity = getArityFromCategory(category);
+
+  auto numDims = getResultRank();
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(arity + 1, numDims, this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `func_type = #linalg.elemwise_fn<add>`
+  Attribute attr;
+  mlir::linalg::ElemwiseFn elemwiseFnVal;
+  if (parser.parseKeyword("func_type"))
+    return failure();
+  if (parser.parseEqual())
+    return failure();
+
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseFnAttr = dyn_cast<ElemwiseFnAttr>(attr);
+    if (!elemwiseFnAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElemwiseFn attribute");
+    elemwiseFnVal = elemwiseFnAttr.getValue();
+  } else {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'func_type' attribute");
+  }
+  result.addAttribute("func_type",
+                      ElemwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+
+  // Parse optional `indexing_maps`
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // At this stage of parsing the only way to infer number of region
+  // args is through op kind, as input output tensors are not parsed yet.
+  auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+  auto arity = getArityFromCategory(arityAndCategory.category);
+  int numRegionArgs = arity + 1 /*output*/;
+
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElemwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the `number of indexing maps` needed from the result
+    // type which is already parsed by now.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+                                           parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+void ElemwiseOp::print(OpAsmPrinter &p) {
+  p << " func_type=";
+  p.printAttribute(getFuncTypeAttr());
+
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "func_type",
+                                           "indexing_maps"};
+
+  auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+  auto arity = getArityFromCategory(category);
+
+  auto numDims = getResultRank();
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ElemwiseOp::verify() {
+  // All necessary checks are done either by
+  // - EnumAttr (e.g. unknown func_type)
+  // - verifyStructuredOpInterface (incorrect map, sizes).
+  return success();
+}
+
+/// Implements the block region builder for the ElemwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  ElemwiseFn elemwiseFn;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("func_type")) {
+      auto funcTypeAttr = dyn_cast<ElemwiseFnAttr>(attr.getValue());
+      assert(funcTypeAttr && "func_type attribute incorrectly set");
+      elemwiseFn = funcTypeAttr.getValue();
+      break;
+    }
+  }
+
+  NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+  ElemwiseNAryCategory category = categoryAndFn.category;
+
+  unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elemwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+
+  if (category == ElemwiseNAryCategory::Unary) {
+    result =
+        helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+
+  } else if (category == ElemwiseNAryCategory::Binary) {
+    result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+                                  block.getArgument(0), block.getArgument(1));
+  } else if (category == ElemwiseNAryCategory::Ternary) {
+    result = helper.buildTernaryFn(categoryAndFn.fn.ternaryFn,
+                                   block.getArgument(0), block.getArgument(1),
+                                   block.getArgument(2));
+  } else {
+    llvm_unreachable("unhandled category in elemwise regionBuilder");
+  }
+
+  yields.push_back(result);
+  helper.yieldOutputs(yields);
+}
+
+LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void ElemwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElemwiseOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir
new file mode 100644
index 00...
[truncated]

@llvmbot
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir-linalg

+// ElemwiseOp - with support for affine map, func_type and comp_type
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct NAryCategoryAndFn {
+  // The enum category class {Unary, Binary, Ternary, ..}
+  ElemwiseNAryCategory category;
+
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  ::llvm::StringRef stringifyCategory() {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return "unary";
+    case ElemwiseNAryCategory::Binary:
+      return "binary";
+    case ElemwiseNAryCategory::Ternary:
+      return "ternary";
+    }
+    llvm_unreachable("unknown-category");
+  }
+
+  ::llvm::StringRef stringifyFn() {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElemwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElemwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+
+unsigned getArityFromCategory(ElemwiseNAryCategory category) {
+  switch (category) {
+  case ElemwiseNAryCategory::Unary:
+    return 1;
+  case ElemwiseNAryCategory::Binary:
+    return 2;
+  case ElemwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+} // namespace
+
+static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) {
+  constexpr int lastUnary = static_cast<int>(ElemwiseFn::erf);
+  constexpr int lastBinary = static_cast<int>(ElemwiseFn::powf);
+  constexpr int lastTernary = static_cast<int>(ElemwiseFn::select);
+
+  int val = static_cast<int>(fn);
+  NAryCategoryAndFn result;
+  if (val <= lastUnary) {
+    result.category = ElemwiseNAryCategory::Unary;
+    result.fn.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val <= lastBinary) {
+    result.category = ElemwiseNAryCategory::Binary;
+    result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+    return result;
+  }
+  if (val > lastTernary) {
+    llvm_unreachable("unhandled ElemwiseFn");
+  }
+  result.category = ElemwiseNAryCategory::Ternary;
+  result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+  return result;
+}
+
+unsigned ElemwiseOp::getResultRank() {
+  auto output = getDpsInitOperand(0)->get();
+  auto shapedType = llvm::cast<ShapedType>(output.getType());
+  return shapedType.getRank();
+}
+
+SmallVector<utils::IteratorType> ElemwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                   MLIRContext *context) {
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+}
+
+bool ElemwiseOp::hasUserDefinedMaps() {
+  auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+  auto arity = getArityFromCategory(category);
+
+  auto numDims = getResultRank();
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(arity + 1, numDims, this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `func_type = #linalg.elemwise_fn<add>`
+  Attribute attr;
+  mlir::linalg::ElemwiseFn elemwiseFnVal;
+  if (parser.parseKeyword("func_type"))
+    return failure();
+  if (parser.parseEqual())
+    return failure();
+
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseFnAttr = dyn_cast<ElemwiseFnAttr>(attr);
+    if (!elemwiseFnAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElemwiseFn attribute");
+    elemwiseFnVal = elemwiseFnAttr.getValue();
+  } else {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'func_type' attribute");
+  }
+  result.addAttribute("func_type",
+                      ElemwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+
+  // Parse optional `indexing_maps`
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // At this stage of parsing, the only way to infer the number of region
+  // args is through the op kind, as input/output tensors are not parsed yet.
+  auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+  auto arity = getArityFromCategory(arityAndCategory.category);
+  int numRegionArgs = arity + 1 /*output*/;
+
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElemwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the `number of indexing maps` needed from the result
+    // type which is already parsed by now.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+                                           parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+void ElemwiseOp::print(OpAsmPrinter &p) {
+  p << " func_type=";
+  p.printAttribute(getFuncTypeAttr());
+
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "func_type",
+                                           "indexing_maps"};
+
+  auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+  auto arity = getArityFromCategory(category);
+
+  auto numDims = getResultRank();
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ElemwiseOp::verify() {
+  // All necessary checks are done either by
+  // - EnumAttr (e.g. unknown func_type)
+  // - verifyStructuredOpInterface (incorrect map, sizes).
+  return success();
+}
+
+/// Implements the block region builder for the ElemwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  ElemwiseFn elemwiseFn;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("func_type")) {
+      auto funcTypeAttr = dyn_cast<ElemwiseFnAttr>(attr.getValue());
+      assert(funcTypeAttr && "func_type attribute incorrectly set");
+      elemwiseFn = funcTypeAttr.getValue();
+      break;
+    }
+  }
+
+  NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+  ElemwiseNAryCategory category = categoryAndFn.category;
+
+  unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elemwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+
+  if (category == ElemwiseNAryCategory::Unary) {
+    result =
+        helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+
+  } else if (category == ElemwiseNAryCategory::Binary) {
+    result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+                                  block.getArgument(0), block.getArgument(1));
+  } else if (category == ElemwiseNAryCategory::Ternary) {
+    result = helper.buildTernaryFn(categoryAndFn.fn.ternaryFn,
+                                   block.getArgument(0), block.getArgument(1),
+                                   block.getArgument(2));
+  } else {
+    assert(false && "found unhandled category in elemwise print");
+  }
+
+  yields.push_back(result);
+  helper.yieldOutputs(yields);
+}
+
+LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void ElemwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElemwiseOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir
new file mode 100644
index 00...
[truncated]


github-actions bot commented Jan 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@adam-smnk adam-smnk left a comment

After the first pass, already looks really good 👍

auto arity = getArityFromCategory(category);

auto numDims = getResultRank();
SmallVector<AffineMap, 3> defaultMaps =
Contributor

nit: it can be just SmallVector<AffineMap> as size changes with arity

I32EnumAttrCase<"erf", 12>,

// Binary

Contributor

nit: unnecessary extra new line

/// or `default`. Default is identity-maps.
static bool hasDynamicIndexingMaps() { return true; }

/// Implements the block region builder for the eemwiseOp. This is called
Contributor

Suggested change
/// Implements the block region builder for the eemwiseOp. This is called
/// Implements the block region builder for the elemwiseOp. This is called

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

I didn't look at all the details, but left a few comments. The structure overall seems fine to me.

// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
//===----------------------------------------------------------------------===//

def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
Contributor

Nit: I'd rather just use ElementWiseOp and element_wise . But totally my preference.

Contributor

+1 for full spelling, feels more intuitive
nit to nit: I'd go with a single word version ElementwiseOp and elementwise, it's also more consistent with spelling across repo

Contributor Author

Nit: I'd rather just use ElementWiseOp and element_wise . But totally my preference.

I too agree. I just wasn't sure what you folks would prefer (longer IR and precise, or shorter and guess-it).

@@ -15,6 +15,57 @@

include "mlir/IR/EnumAttr.td"

// Define an `enum class : i32` to categorise element-wise op.
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
Contributor

This might be frowned upon, but I don't think we need two different sets of enums. We could just use a range of enums in ElemwiseFn below to denote unary/binary and ternary functions....
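The range-based classification this comment suggests (and which the patch's `getNAryCategoryAndFn` already leans on) can be sketched outside MLIR. The enum entries and boundary names below are illustrative stand-ins, not the real `ElemwiseFn` values:

```python
from enum import IntEnum

# Hypothetical flat enum mirroring ElemwiseFn: unary entries first,
# then binary, then ternary. Only the range-partition idea matters.
class ElemwiseFn(IntEnum):
    exp = 0
    log = 1
    erf = 2      # last unary entry
    add = 3
    mul = 4
    powf = 5     # last binary entry
    select = 6   # last (ternary) entry

LAST_UNARY = ElemwiseFn.erf
LAST_BINARY = ElemwiseFn.powf

def arity(fn: ElemwiseFn) -> int:
    """Classify a function purely by its position in the enum range."""
    if fn <= LAST_UNARY:
        return 1
    if fn <= LAST_BINARY:
        return 2
    return 3
```

With this layout a single enum suffices: the unary/binary/ternary category is recovered from the value, so no second `ElemwiseNAryCategory` enum is strictly needed.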

return defaultMaps != explicitMaps;
}

ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
Contributor

Same comment as what I left on the matmul op PR. I think this could be done with assembly format with just custom parser/printer functions.

Contributor

@banach-space banach-space left a comment

Thanks @javedabsar1 , excited to see this materialising 🙇🏻 🙏🏻

Some minor comments in tests while I go over the other details.

Contributor

Why do we need a new directory?

Contributor Author

@javedabsar1 javedabsar1 Jan 15, 2025

because I reckon the number of tests will increase over time as elementwise develops further.

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @fewer_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
Contributor

[nit]

Suggested change
func.func @fewer_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @incorrect_transpose_map(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
Contributor

[nit]

Suggested change
func.func @incorrect_transpose_map(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {

// CHECK: @multiple_dims(%[[A]]: tensor<1x2x3x4x5xi32>, %[[B:.+]]: tensor<1x2x3x4x5xi32>, %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>) outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
Contributor

[nit] Please use consistent naming

Suggested change
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
func.func @multiple_dims(%A : tensor<1x2x3x4x5xi32>, %B : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {

Comment on lines +67 to +71
%r = linalg.elemwise func_type=#linalg.elemwise_fn<mul>
indexing_maps = [#map, #map, #map]
ins(%arg0, %arg1: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
outs(%arg2: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
Contributor

This indentation is different to the rest of the file. Please unify. Personally I prefer this "reduced" indentation.

// CHECK: @binary_mul(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @binary_mul(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
Contributor

[nit] To make a clearer distinction from multiple_dims

Suggested change
func.func @binary_mul(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
func.func @binary_mu_2Dl(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {

Member

er, I think you mean binary_mul_2D 😆

// CHECK: @multiple_dims(%[[A]]: tensor<1x2x3x4x5xi32>, %[[B:.+]]: tensor<1x2x3x4x5xi32>, %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>) outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
Contributor

[nit] To make a clearer distinction from @binary_mul

Suggested change
func.func @multiple_dims(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
func.func @binary_mul_5D(%arg0 : tensor<1x2x3x4x5xi32>, %arg1 : tensor<1x2x3x4x5xi32>, %arg2 : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {

Contributor

[nit] Could you make sure that the test function naming is consistent in this file? From what I can see, it goes like this:

  • @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}

I'm not sure how to handle on_tensor and on_memref 😅

@banach-space
Contributor

High-level question for everyone: Please feel free to shut this thread down if you think this isn't worth discussing (to avoid bike-shedding).

To me, func_type feels somewhat confusing or counter-intuitive—does it refer to a programming language type? For example:

func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
  %r = linalg.elemwise
               func_type=#linalg.elemwise_fn<exp>
               ins(%A : tensor<8x16x32xf32>)
               outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
  return %r : tensor<8x16x32xf32>
}

Would something like math_op or body_op be clearer (or even op)?

Ultimately, it seems to describe "a mathematical operation inside the Linalg Op body." Do these alternatives make sense, or do you feel func_type is already intuitive? WDYT?

@adam-smnk
Contributor

Would something like math_op or body_op be clearer (or even op)?

To me, they all (together with func_type ) sound arbitrary.
kind is another option that seems somewhat popular like in vector.contract.

Nothing better comes to my mind but I wouldn't mind if we could align the attr name with something already existing in the wild. Maybe even sth from PyTorch or IREE? Such that most people could recognize it more intuitively.

@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
Contributor

[nit] Lets stick with the existing convention.

Suggested change
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
// Op definition for ElemwiseOp

let description = [{
Linalg op form which performs element-wise computation. The attribute
`func_type` describes the operation type (e.g. add, exp). The func_type
can be any valid unary, binary, or ternary operation.
Contributor

How are we defining "any valid"? I would either list all or skip this altogether.

Comment on lines +568 to +569
for all operands. The default indexing maps are N identity-maps. ‘N’
depends on the arity of the elementwise op. The number of dims is
Contributor

Isn't N simply the rank of the output? And doesn't "arity" mean the number of function args? Is that relevant here?

Comment on lines +570 to +571
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
Contributor

Suggested change
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
inferred from rank of the output type. In the case of default indexing
maps, the input and output shapes must match. Affine-maps for operands

depends on the arity of the elementwise op. The number of dims is
inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
and result must be only projected permutations with no zero constants.
Contributor

Suggested change
and result must be only projected permutations with no zero constants.
and result must be projected permutations with no zero constants.

/// the default (identity) map.
bool hasUserDefinedMaps();

/// destination passing style interface method.
Contributor

Suggested change
/// destination passing style interface method.
/// Destination passing style interface method.

Comment on lines +651 to +653
/// Infer dimensionality of the `iteration space` from the result type.
/// Useful when other means are not possible e.g. in case of absence of
/// user-provided indexing map.
Contributor

Doesn't this simply return the rank of the result? There isn't much "inference" involved 😅 As per your implementation, return shapedType.getRank();.

Comment on lines +656 to +658
/// Elementwise op does not have to explicitly specify iterator type
/// as it is always 'parallel'. The number of 'parallel' loops is
/// inferred from other means (e.g. result tensor type).
Contributor

This part is already included in the general Op description above:

  /// Elementwise op does not have to explicitly specify iterator type
/// as it is always 'parallel'

Please only document what getIteratorTypesArray does.

Comment on lines +661 to +663
/// The default indexing maps are N identity-maps. 'N' depends on the
/// arity of the elementwise op. The default case is when all input
/// output tensors are same rank and no transpose/broadcast is needed.
Contributor

It would be good to refine this, to me "N identity-maps" could mean "identity-maps with N dims". Perhaps,

The default indexing maps are identities. There will be N such maps, where N is the arity of the Op.
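To make the counting concrete, here is a tiny Python model of what `getDefaultIndexingMaps` produces. Maps are written as tuples of dimension indices; this mirrors the MLIR code for illustration only and is not the actual API:

```python
def default_indexing_maps(num_maps, num_dims):
    """Toy model of ElemwiseOp::getDefaultIndexingMaps: `num_maps` copies
    of the rank-`num_dims` identity map, each written as a tuple of
    dimension indices (d0, d1, ...)."""
    identity = tuple(range(num_dims))
    return [identity] * num_maps

# A binary op (arity 2) over a rank-3 result needs 2 inputs + 1 init = 3 maps.
maps = default_indexing_maps(num_maps=3, num_dims=3)
```

So N (the number of maps) comes from the arity plus one for the init, while the number of dims in each map comes from the result rank.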

`func_type` describes the operation type (e.g. add, exp). The func_type
can be any valid unary, binary, or ternary operation.

Affine-maps for operands and result may be provided by the user. When
Contributor

Affine-maps for operands and result may be provided by the user.

IIUC, maps are required when transposing or broadcasting?
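With the default identity maps no transpose or broadcast is expressible, so user-provided maps exist exactly for those cases. A hypothetical example in the syntax this PR introduces (function name and shapes invented for illustration):

```mlir
// Transpose of the first operand expressed purely through indexing maps.
#transposed = affine_map<(d0, d1) -> (d1, d0)>
#identity   = affine_map<(d0, d1) -> (d0, d1)>
func.func @add_transpose_a(%A : tensor<8x16xf32>, %B : tensor<16x8xf32>,
                           %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
  %r = linalg.elemwise func_type=#linalg.elemwise_fn<add>
         indexing_maps = [#transposed, #identity, #identity]
         ins(%A, %B : tensor<8x16xf32>, tensor<16x8xf32>)
         outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
  return %r : tensor<16x8xf32>
}
```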

@javedabsar1
Contributor Author

javedabsar1 commented Jan 15, 2025

Maybe even sth from PyTorch or IREE? Such that most people could recognize it more intuitively.

Whatever we choose to call it, we need to keep in mind that we will have two attributes that should be consistently named: func_type/func_kind/opKind/computationKind.... (e.g. add) and computation_type/accumulator_type (e.g. i32).
It should also be consistent with whatever other new ops (contract/pooling/...) call and need.

How about comp_kind and comp_type? Maybe just vote with thumbs up/down, or suggest better ones.

Member

@rengolin rengolin left a comment

can we have tests that show binary ops with single arguments / unary ops with more than one being rejected?
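A negative test along those lines could look like the following sketch; the diagnostic text is a guess, since the exact verifier message is not shown in this PR:

```mlir
// Hypothetical invalid-arity test: a binary function with one input.
func.func @binary_mul_single_input(%A : memref<16x8xf32>, %C : memref<16x8xf32>) {
  // expected-error @+1 {{expected 2 inputs for a binary elementwise function}}
  linalg.elemwise func_type=#linalg.elemwise_fn<mul>
    ins(%A : memref<16x8xf32>) outs(%C : memref<16x8xf32>)
  return
}
```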

// CHECK: @binary_mul(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elemwise func_type=#linalg.elemwise_fn<mul> ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @binary_mul(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
Member

er, I think you mean binary_mul_2D 😆

@banach-space
Contributor

Maybe even sth from PyTorch or IREE? Such that most people could recognize it more intuitively.

Whatever we choose to call it, we need to keep in mind that we will have two attributes that should be consistently named: func_type/func_kind/opKind/computationKind.... (e.g. add) and computation_type/accumulator_type (e.g. i32).

IIUC, we don't need "compute type" just now? Shall we cross that bridge when we get there and in the meantime prioritise consistency and use kind (@adam-smnk - I like your suggestion!)?

@@ -15,6 +15,57 @@

include "mlir/IR/EnumAttr.td"

// Define an `enum class : i32` to categorise element-wise op.
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
I32EnumAttrCase<"Unary", 0>,
Contributor

Do we need separate enums here? Can we just use the enum below to classify what is a unary/binary or ternary function? (I think this is a repeat comment from above)

);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
Contributor

Do we need a region on the op?

Contributor

Just to cross-pollinate the Linalg discussions.
It looks like regions are required by Linalg's interface. Thus, all ops currently have and should have a region.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Left a couple of other comments. I think it would be useful to not have regions on the operations since they seem unnecessary. Apart from that this looks good to me.

@javedabsar1
Contributor Author

created a new PR -- #124661

6 participants