Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Extend elementwise #124661

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

javedabsar1
Copy link
Contributor

@javedabsar1 javedabsar1 commented Jan 28, 2025

Implements Linalg elemwise named-op following the proposal and discussions in RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1

I address the comments made in previous PR #122753 but closed that and opened this one as I messed the merge and could not recover.

Just to recap, the main changes I made :

  • instead of duplicating Unary/Binary enums to create a new unified one, I have now added some tablegen functions 'join'/'flatten' to derive unified list. A simple list concat would not work as enums values would clash,
  • Fixed tests naming and added more tests as requested.
  • Addressed typos/nitpicks.
  • changed 'comp_type' to 'kind' (e.g. add, sub) by popular request.

Some comments that that i did not (could not) do code change:

  • default print will not work as it would print default maps.
  • region is needed to lower to linalg.generic.

@llvmbot
Copy link
Member

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Implements Linalg elemwise named-op following the proposal and discussions in RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1

I address the comments made in previous PR #122753 but closed that and opened this one as I messed the merge and could not recover.

Just to recap, the main changes I made :

  • instead of duplicating Unary/Binary enums to create a new unified one, I have now added some tablegen functions 'join'/'flatten' to derive unified list. A simple list concat would not work as enums values would clash,
  • Fix tests naming and added more tests as requested.
  • Addressed typos/nitpicks.
  • change 'comp_type' to 'kind' (e.g. add, sub) by popular request
  • default print will not do as it would print default maps.
  • region is needed to lower to linalg.generic.

Patch is 33.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124661.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+6)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+44)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+116)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+263)
  • (added) mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir (+157)
  • (added) mlir/test/Dialect/Linalg/element_wise/invalid.mlir (+54)
  • (added) mlir/test/Dialect/Linalg/element_wise/round-trip.mlir (+88)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..00e3633610ccb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
+// Define the attribute enums matching elementwise op function (e.g., add).
+def ElementwiseFnAttr : EnumAttr<Linalg_Dialect,
+                                 ElementwiseFn, "elementwise_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
   let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..c41b5418357514 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -55,6 +55,50 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
 }
+
+// Join two I32EnumAttrCase lists. This joining takes care that the
+// 'int enum values' in the combined list do not overlap. It does this
+// by adding to each element of second list the offset '!size(a)'.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+                                  list<I32EnumAttrCase> b> {
+  int aSize = !size(a);
+  list<I32EnumAttrCase> result =
+             !foldl(a, b, acc, var,
+                    acc # [I32EnumAttrCase<var.symbol,
+                                           !add(var.value, aSize)
+                                           >]);
+}
+
+// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
+// The flattening (via call to 'join') ensures no overlap in enum values.
+class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
+  list<I32EnumAttrCase> result =
+             !foldl([]<I32EnumAttrCase>, l, acc, var,
+                    JoinTwoI32EnumAttrCaseList<acc, var>.result);
+}
+
+// Define a unified `enum class : i32` for all element-wise op functions.
+def ElementwiseFn :
+            I32EnumAttr<"ElementwiseFn",
+                        "",
+                        ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
+                                                  BinaryFn.enumerants,
+                                                  TernaryFn.enumerants]>.result
+                      > {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` to categorise elementwise ops.
+def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
+  I32EnumAttrCase<"Unary", 0>,
+  I32EnumAttrCase<"Binary", 1>,
+  I32EnumAttrCase<"Ternary", 2>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..2d82eef41c2f2f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,122 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ElementwiseOp
+//===----------------------------------------------------------------------===//
+def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
+                   AttrSizedOperandSegments]> {
+  let summary = [{ Performs element-wise operation }];
+  let description = [{
+    Linalg op form which performs element-wise computation.
+
+    The attribute `kind` describes the operation (e.g. add, exp). The operation
+    kind can be any elementwise nary (e.g. unary, binary) operation.
+
+    Affine-maps for operands and result are reuired to be provided by the user
+    when transpose and/or broadcast is needed on any operand. When a map is not
+    provided, default identity maps are inferred for each operand. The number
+    of dims in each of the identity maps is equal to the rank of the output type.
+    In the case of default indexing map, all input and output shapes must match.
+    User-defined Affine-map for operands and result must only be projected
+    permutations with no zero constants.
+
+    For elementwise, iterator-types are always 'all parallel’.
+    Iterator-types are needed for constructing the underlying structured op.
+    The number of dims of the iterator-types are inferred from the rank of
+    the result type.
+
+    Example:
+
+    Defining a unary linalg.elemwise with default indexing-map:
+      ```mlir
+      %exp = linalg.elemwise
+             kind=#linalg.elemwise_fn<exp>
+             ins(%x : tensor<4x16x8xf32>)
+             outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+      ```
+
+    Defining a binary linalg.elemwise with user-defined indexing-map:
+    ```mlir
+    %add = linalg.elemwise
+            kind=#linalg.elemwise_fn<add>
+            indexing_maps = [#transpose, #broadcast, #identity]
+            ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+            outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      ElementwiseFnAttr:$kind,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, ElementwiseOp::getRegionBuilder());
+      }]>
+    ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+      /// Get the nary category enum, e.g. `ElementwiseNAryCategory::Unary`,
+      /// corresponding to the given fn, e.g. `ElementwiseFn::exp`
+      static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);
+
+      /// Both user-specified and default indexing map will always depend on
+      ///  the current Op instance.
+      static bool hasDynamicIndexingMaps() { return true; }
+
+      /// Implements the block region builder for the elementwiseOp. This is
+      /// called by the 'fillStructuredOpRegion'.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns rank of the result tensor/memref. Useful for knowing
+      /// the dimensionality of the iteration space when others means
+      /// are not possible e.g. absence of user-provided indexing map.
+      unsigned getResultRank();
+
+      /// Returns N 'parallel' iterator types where N is rank of result.
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// The default indexing maps are identities.
+      /// There will be N such maps, where N is the arity of the Op.
+      static SmallVector<AffineMap>
+      getDefaultIndexingMaps(unsigned N, unsigned numDims,
+                             MLIRContext *context);
+
+      /// Destination passing style interface method.
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      std::string getLibraryCallName() {
+        return generateLibraryCallName(getOperation());
+      }
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..e1947a864d4d0d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildElementwiseOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder) {
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3611,5 +3620,259 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+
+struct NAryCategoryAndFn {
+  // The enum category class {Unary, Binary, Ternary, ..}
+  ElementwiseNAryCategory category;
+
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  ::llvm::StringRef stringifyCategory() {
+    return stringifyElementwiseNAryCategory(category);
+  }
+
+  ::llvm::StringRef stringifyFn() {
+    switch (category) {
+    case ElementwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElementwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElementwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+
+unsigned getArityFromCategory(ElementwiseNAryCategory category) {
+  switch (category) {
+  case ElementwiseNAryCategory::Unary:
+    return 1;
+  case ElementwiseNAryCategory::Binary:
+    return 2;
+  case ElementwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+} // namespace
+
+static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
+  constexpr int lastUnary = static_cast<int>(ElementwiseFn::erf);
+  constexpr int lastBinary = static_cast<int>(ElementwiseFn::powf);
+  constexpr int lastTernary = static_cast<int>(ElementwiseFn::select);
+
+  int val = static_cast<int>(fn);
+  NAryCategoryAndFn result;
+
+  if (val <= lastUnary) {
+    result.category = ElementwiseNAryCategory::Unary;
+    result.fn.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val <= lastBinary) {
+    result.category = ElementwiseNAryCategory::Binary;
+    result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+    return result;
+  }
+  if (val > lastTernary) {
+    llvm_unreachable("unhandled ElementwiseFn");
+  }
+  result.category = ElementwiseNAryCategory::Ternary;
+  result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+  return result;
+}
+
+unsigned ElementwiseOp::getResultRank() {
+  auto output = getDpsInitOperand(0)->get();
+  auto shapedType = llvm::cast<ShapedType>(output.getType());
+  return shapedType.getRank();
+}
+
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+}
+
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `kind = #linalg.elemwise_fn<add>`
+  Attribute attr;
+  mlir::linalg::ElementwiseFn elemwiseFnVal;
+  if (parser.parseKeyword("kind"))
+    return failure();
+  if (parser.parseEqual())
+    return failure();
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseFnAttr = dyn_cast<ElementwiseFnAttr>(attr);
+    if (!elemwiseFnAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElementwiseFn attribute");
+    elemwiseFnVal = elemwiseFnAttr.getValue();
+  } else {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected operation 'kind' attribute");
+  }
+  result.addAttribute(
+      "kind", ElementwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+
+  // Parse optional `indexing_maps`
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // At this stage of parsing the only way to infer number of region
+  // args is through op kind, as input output tensors are not parsed yet.
+  auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+  auto arity = getArityFromCategory(arityAndCategory.category);
+  int numRegionArgs = arity + 1 /*output*/;
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElementwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the `number of indexing maps` needed from the result
+    // type which is already parsed by now.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+                                              parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+void ElementwiseOp::print(OpAsmPrinter &p) {
+  p << " kind=";
+  p.printAttribute(getKindAttr());
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
+  auto category = getNAryCategoryAndFn(getKind()).category;
+  auto arity = getArityFromCategory(category);
+  auto numDims = getResultRank();
+
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ElementwiseOp::verify() {
+  // All necessary checks are done either by
+  // - EnumAttr (e.g. unknown operation kind)
+  // - verifyStructuredOpInterface (incorrect map, sizes).
+  return success();
+}
+
+/// Implements the block region builder for the ElementwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  ElementwiseFn elemwiseFn;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("kind")) {
+      auto funcTypeAttr = dyn_cast<ElementwiseFnAttr>(attr.getValue());
+      assert(funcTypeAttr && "op kind attribute incorrectly set");
+      elemwiseFn = funcTypeAttr.getValue();
+      break;
+    }
+  }
+
+  NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+  ElementwiseNAryCategory category = categoryAndFn.category;
+  unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elementwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+
+  if (category == ElementwiseNAryCategory::Unary) {
+    result =
+        helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+  } else if (category == ElementwiseNAryCategory::Binary) {
+    result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+                                  block.getArgument(0), block.getArgument(1));
+  } else if (category == ElementwiseNAryCategory::Ternary) {
+    result =
+        helper.buildTernaryFn(categoryAndFn.fn.ternaryFn, block.getArgument(0),
+                              block.getArgument(1), block.getArgument(2));
+  } else
+    assert(false && "found unhandled category in elemwise print");
+
+  yields.push_back(result);
+  helper.yieldOutputs(yields);
+}
+
+LogicalResult ElementwiseOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ElementwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElementwiseOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
new file mode 100644
index 00000000000000..2466a77acc236e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK:   %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
+// CHECK:   linalg.yield %[[EXP]] : f32
+//
+func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_fn<exp>
+               ins(%A : tensor<8x16x32xf32>)
+               outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_transpose_br...
[truncated]

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some small comments, but overall looking good. Thanks!

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Outdated Show resolved Hide resolved
User-defined Affine-map for operands and result must only be projected
permutations with no zero constants.

For elementwise, iterator-types are always 'all parallel’.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For elementwise, iterator-types are always 'all parallel’.
For elementwise, iterator-types are always all `parallel’.

}];

let arguments = (ins
Variadic<AnyType>:$inputs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to accept AnyType on the input and not AnyShaped?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks. changed to AnyShaped.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a comment about this below: the primary reason to allow AnyType is so that a scalar operand can be broadcast to the full output shape.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated Show resolved Hide resolved
// Join two I32EnumAttrCase lists. This joining takes care that the
// 'int enum values' in the combined list do not overlap. It does this
// by adding to each element of second list the offset '!size(a)'.
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud, perhaps it would be easier to use bit patterns instead of joining enum lists. We won't have more than 20 operations per Category, so:

  • Unary: (op | (0xFF << 1))
  • Binary: (op | (0xFF << 2))
  • Ternary: (op | (0xFF << 3))

And set the enums above like:

  • I32EnumAttrCase<"log", (1 << 1)>
  • I32EnumAttrCase<"sub", (1 << 2)>
  • I32EnumAttrCase<"select", (1 << 3)>

etc?

Then you don't need all the complex sequential logic in the parser/verifiers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need anymore.

Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the alternative join-based approach, @javedabsar1 - It's quite impressive what can done with TableGen! (For better or worse, TableGen is a programming language on its own.)

I would still think an approach like @rengolin's would lead to simpler C++. The scheme I have in mind is to just shift the arity 30 bits, e.g. I32EnumAttrCase<"abs", 2> becomes I32EnumAttrCase<"abs", 2 + (1 << 30)>,
I32EnumAttrCase<"div", 3 + (2 << 30)>
I32EnumAttrCase<"select", 0 + (3 << 30)>. This way the arity can be retrieved by just shifting right 30 bits (e.g. derivedEnumVal >> 30) and to obtain the original op code you just do derivedEnumVal & ((1 << 30) - 1).

Three nested !lfolds should now suffice to derive all the derived enum cases and ElementwiseFnLimits could go and NAryCategoryAndFn could go or be simplified.

What do you think? (If you could just state the benefits of your approach, that would also be fine OFC.)

@rengolin rengolin requested a review from rolfmorel February 3, 2025 12:27
@rengolin
Copy link
Member

rengolin commented Feb 3, 2025

Thanks @javedabsar1 this looks good to me now! I'd like @banach-space and @MaheshRavishankar to have a look, as they participated in the design discussions, too. Also, @rolfmorel, @shahidact and @adam-smnk who have been active on the area.

}

// Define an `enum class : i32` to categorise elementwise ops.
def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this just be ElementwiseArity? The concept of NAryCategory is exactly arity, right?

Linalg op form which performs element-wise computation.

The attribute `kind` describes the operation (e.g. add, exp). The operation
kind can be any elementwise nary (e.g. unary, binary) operation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rephrase to use arity instead of nary

kind can be any elementwise nary (e.g. unary, binary) operation.

Affine-maps for operands and result are required to be provided by the user
when transpose and/or broadcast is needed on any operand. When a map is not
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: when a transpose ...

User-defined affine-map for operands and result must only be projected
permutations with no zero constants.

For elementwise, iterator-types are always `all parallel`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: all parallel

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, skip For elementwise. This is duplicating info.

}];

let arguments = (ins
Variadic<AnyShaped>:$inputs,
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow broadcasting a scalar (e.g. as is supported by generic, contract and matmul)? If so, this AnyShaped should probably be AnyType.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 we should allow that

(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we inline buildElementwiseOp here? As buildElementwiseOp is just forwarding args to buildStructuredOp, it seems the separate function (not close to any other elementwise op code) is unnecessary.

let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the nary category enum, e.g. `ElementwiseNAryCategory::Unary`,
/// corresponding to the given fn, e.g. `ElementwiseFn::exp`
static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ElementwiseNAryCategory just encodes the arity, do we need a separate enum for this? Would just an unsigned int suffice? (The thing we would lose is that the enum encodes/enforce arities are restricted to unary, binary and ternary - maybe that's sufficient reason to keep it.)

static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);

/// Both user-specified and default indexing map will always depend on
/// the current Op instance.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: double space

unsigned getResultRank();

/// Returns N 'parallel' iterator types where N is rank of result.
SmallVector<utils::IteratorType> getIteratorTypesArray();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the availability of getResultRank(), sounds like the definition could be here in the .td file.

SmallVector<utils::IteratorType> getIteratorTypesArray();

/// The default indexing maps are identities.
/// There will be N such maps, where N is the arity of the Op.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should "N" maybe be "arity"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and numDims could be rank, right?

}

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consistent arg names w.r.t. declaration

}

SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
auto rank = getResultRank();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: expand auto

}

unsigned ElementwiseOp::getResultRank() {
auto output = getDpsInitOperand(0)->get();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: expand auto

Comment on lines +571 to +572
User-defined affine-map for operands and result must only be projected
permutations with no zero constants.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason for this restriction? This restriction does not seem relevant here. The op is still elementwise, no matter how you define the the iteration map.

Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question! Are there any considerations arising from the "linalg tree" for why the indexing_maps should be projected permutations?

permutations with no zero constants.

For elementwise, iterator-types are always `all parallel`.
Iterator-types are needed for constructing the underlying structured op.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is relevant here. This is somehow implying the op stores iterator types on the operation. This is confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on this.

}];

let arguments = (ins
Variadic<AnyShaped>:$inputs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 we should allow that

Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseFnAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this should be DefaultValuedOptionalAttr. "{}" can be invalid in many cases. Instead, we should just have a builder for having a derived default value of the attribute.

mlir::linalg::ElementwiseFn elemwiseFnVal;
if (parser.parseKeyword("kind"))
return failure();
if (parser.parseEqual())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slightly terser:

if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();

SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
if (parser.parseEqual())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above: can use ||

// Parse optional `indexing_maps`
SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Upon the linalg.contract PR getting merged, this file gained the FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) helper. Would be good to deduplicate code with that.

@Groverkss
Copy link
Member

region is needed to lower to linalg.generic

This is not really true. LinalgStructuredInterface exposes getRegionBuilder which should be used to lower to linalg.generic


// At this stage of parsing the only way to infer number of region
// args is through op kind, as input output tensors are not parsed yet.
auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm - in which sense are arity and category distinct concepts? (If it's just the range restriction, in what sense do we obtain both the arity and category here?)


// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
// We need to infer the `number of indexing maps` needed from the result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you (need to) infer the rank of the affine maps from the output type, not the number of maps.


SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of turning values into attributes here (which can be quite costly in a couple of respects), how about extracting the AffineMap values associated to getIndexingMaps and do an equality check on two non-attribute-based iterators?

LogicalResult ElementwiseOp::verify() {
// All necessary checks are done either by
// - EnumAttr (e.g. unknown operation kind)
// - verifyStructuredOpInterface (incorrect map, sizes).
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand: this definition is overriding the default implementation of verify that does call out to verifyStructuredOpInterface - how are the necessary affine_maps checks still performed if we don't invoke verifyStructuredOpInterface? Given that tests in invalid.mlir seem to check at least some of these properties, I am probably just not familiar with some secondary mechanism for invoking the right verifier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also quite confused. And suspect that this is the root cause of a bit weird error messages?

void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs) {
ElementwiseFn elemwiseFn;
for (auto attr : attrs) {
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe regionBuilder will only be invoked when getKind() should already give back the correct answer.


if (category == ElementwiseNAryCategory::Unary) {
result =
helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case I understand correctly, categoryAndFn.fn.unaryFn and ...fn.binaryFn and ...fn.ternaryFn are only used in this function. Could we move the logic for retrieving the actual function here and simplify/get rid of NAryCategoryAndFn?

helper.buildTernaryFn(categoryAndFn.fn.ternaryFn, block.getArgument(0),
block.getArgument(1), block.getArgument(2));
} else
assert(false && "found unhandled category in elemwise print");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this function is not about printing

func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_fn<select>
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put the affine_maps on linalg.elementwise on multiple lines? Would be easier to check/read.

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
%r = linalg.elementwise kind=#linalg.elementwise_fn<mul>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: kind is on own line everywhere else

// expected-error@+3 {{expected ::mlir::linalg::ElementwiseFn to be one of: exp, log, abs, ceil, floor}}
// expected-error@+2 {{failed to parse ElementwiseFnAttr parameter}}
// expected-error@+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
linalg.elementwise kind=#linalg.elementwise_fn<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The multi-line style used in the other files is a bit easier on the eye. Would be appreciated if you could use it in this file as well. 👍

@rolfmorel
Copy link
Contributor

rolfmorel commented Feb 3, 2025

region is needed to lower to linalg.generic

This is not really true. LinalgStructuredInterface exposes getRegionBuilder which should be used to lower to linalg.generic

Unfornately, this is not case in the world we live in 😢 Have a look at

rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
- that is, the actual region is copied from the named op.

getRegionBuilder seems to just be just a hook into the Structured Op builder API, which really runs upon op construction/build time.

@rolfmorel
Copy link
Contributor

This is looking good, @javedabsar1 ! I have left a bunch of minor comments. The only one I feel is more significant is on whether we can do with a simpler representation and C++ for functions-and-their-arities.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lovely patch, @javedabsar1 , thank you! Clearly in line with

and consistent with some other recent changes to Linalg (looking at indexing maps):

I've skimmed through the easier parts and left some nits.

I have one high level comment regarding test organisation. If we are to introduce a dedicated sub-dir for this Op (test/Dialect/Linalg/element-wise), then we should do the same for other Ops. I'd be OK with that and just want to make sure we test everything consistently. @rolfmorel , to me that would mean moving tests for linalg.contract to test/Dialect/Linalg/contract (*).

(*) I mean the most standard tests, e.g. "invalid.mlir" etc.

@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}

// Define the attribute enums matching elementwise op function (e.g., add).
def ElementwiseFnAttr : EnumAttr<Linalg_Dialect,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that we would be switching to kind instead of Fn/function etc?

#122753 (comment)

Also, you already seem to be using kind in various places.

AttrSizedOperandSegments]> {
let summary = [{ Performs element-wise operation }];
let description = [{
Linalg op form which performs element-wise computation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] This is repeating info from the summary, deleteme.

let description = [{
Linalg op form which performs element-wise computation.

The attribute `kind` describes the operation (e.g. add, exp). The operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit]

Suggested change
The attribute `kind` describes the operation (e.g. add, exp). The operation
The attribute `kind` describes the arithmetic operation to perform. This operation
can either be unary (e.g. max), binary (e.g. add) or ternary (i.e. select).

Comment on lines +566 to +567
Affine-maps for operands and result are required to be provided by the user
when transpose and/or broadcast is needed on any operand. When a map is not
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Could you start with the default behaviour (e.g. "By default, all indexing maps are identities.")?
  2. Is it OK to specify only one (or two) of the maps? Or is it "either all or nothing"? Please clarify.

User-defined affine-map for operands and result must only be projected
permutations with no zero constants.

For elementwise, iterator-types are always `all parallel`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, skip For elementwise. This is duplicating info.

@@ -0,0 +1,54 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error@+3 {{expected ::mlir::linalg::ElementwiseFn to be one of: exp, log, abs, ceil, floor}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Some errors start with 'linalg.elementwise' op and some without. Please be consistent.
  2. Why is it printing the C++ name (::mlir::linalg::ElementwiseFn) rather than some MLIR equivalent?

// -----

func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
// expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does custom op come from? And parseNamedStructuredOpRegion? This error is a bit confusing/convoluted.

@@ -0,0 +1,88 @@
// RUN: mlir-opt %s -split-input-file | FileCheck %s
//
// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is almost identical to what I proposed in one of my comments ❤️ Could you use the same scheme throughout this PR? 🙏🏻 :)

// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
//
func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Inconsistent indentation. Same for other examples.

LogicalResult ElementwiseOp::verify() {
// All necessary checks are done either by
// - EnumAttr (e.g. unknown operation kind)
// - verifyStructuredOpInterface (incorrect map, sizes).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also quite confused. And suspect that this is the root cause of a bit weird error messages?


Example:

Defining a unary linalg.elemwise with default indexing-map:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old op's name, needs updating

@rolfmorel
Copy link
Contributor

Sorry to bring this up late, but I feel others might also have missed the elephant in the room:

How is this new linalg.elementwise/linalg.elemwise op supposed to relate to the existing linalg.elemwise_unary and linalg.elemwise_binary ops?

Somehow I completely missed that these ops already exist. TBF, the RFC and PRs seem to have left this topic out as well.

@javedabsar1, given that their existence provides a clear design alternative (i.e. extend them to support indexing maps and call it a day), could you help me/us understand the benefit of going with a new op? In particular:

  • What's the benefit of lumping these ops under just one op? That is, is there a benefit to lumping arities?
  • Is the intention to deprecate linalg.elemwise_unary and linalg.elemwise_binary?
    • If not, why not keep the IR representation much closer to the existing ops? E.g. just a fun attribute in attr-dict instead of discussing kind etc.

@rolfmorel
Copy link
Contributor

I have one high level comment regarding test organisation. If we are to introduce a dedicated sub-dir for this Op (test/Dialect/Linalg/element-wise), then we should do the same for other Ops. I'd be OK with that and just want to make sure we test everything consistently. @rolfmorel , to me that would mean moving tests for linalg.contract to test/Dialect/Linalg/contract (*).

(*) I mean the most standard tests, e.g. "invalid.mlir" etc.

I would be okay with moving tests to dedicated subdirs - at the moment linalg tests seem quite scattered around to me. With the "linalg tree" in mind, rather than having dirs per op, we could have dirs per sub-tree. That is, we could have a subdir for all named ops that perform contractions and a separate subdir for all named ops that operate elementwise. Not sure if that's the best way to go though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants