Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Extend elementwise #124661

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}

// Define the attribute enums matching elementwise op function (e.g., add).
def ElementwiseFnAttr : EnumAttr<Linalg_Dialect,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that we would be switching to kind instead of Fn/function etc?

#122753 (comment)

Also, you already seem to be using kind in various places.

ElementwiseFn, "elementwise_fn"> {
let assemblyFormat = "`<` $value `>`";
}

// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
Expand Down
59 changes: 59 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Join two I32EnumAttrCase lists. This joining takes care that the
// 'int enum values' in the combined list do not overlap. It does this
// by adding to each element of second list the offset '!size(a)'.
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud, perhaps it would be easier to use bit patterns instead of joining enum lists. We won't have more than 20 operations per Category, so:

  • Unary: (op | (0xFF << 1))
  • Binary: (op | (0xFF << 2))
  • Ternary: (op | (0xFF << 3))

And set the enums above like:

  • I32EnumAttrCase<"log", (1 << 1)>
  • I32EnumAttrCase<"sub", (1 << 2)>
  • I32EnumAttrCase<"select", (1 << 3)>

etc?

Then you don't need all the complex sequential logic in the parser/verifiers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need anymore.

Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the alternative join-based approach, @javedabsar1 - It's quite impressive what can done with TableGen! (For better or worse, TableGen is a programming language on its own.)

I would still think an approach like @rengolin's would lead to simpler C++. The scheme I have in mind is to just shift the arity 30 bits, e.g. I32EnumAttrCase<"abs", 2> becomes I32EnumAttrCase<"abs", 2 + (1 << 30)>,
I32EnumAttrCase<"div", 3 + (2 << 30)>
I32EnumAttrCase<"select", 0 + (3 << 30)>. This way the arity can be retrieved by just shifting right 30 bits (e.g. derivedEnumVal >> 30) and to obtain the original op code you just do derivedEnumVal & ((1 << 30) - 1).

Three nested !lfolds should now suffice to derive all the derived enum cases and ElementwiseFnLimits could go and NAryCategoryAndFn could go or be simplified.

What do you think? (If you could just state the benefits of your approach, that would also be fine OFC.)

list<I32EnumAttrCase> b> {
int aSize = !size(a);
list<I32EnumAttrCase> result =
!foldl(a, b, acc, var,
acc # [I32EnumAttrCase<var.symbol,
!add(var.value, aSize)
>]);
}

// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
// The flattening (via call to 'join') ensures no overlap in enum values.
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
list<I32EnumAttrCase> result =
!foldl([]<I32EnumAttrCase>, l, acc, var,
JoinTwoI32EnumAttrCaseList<acc, var>.result);
}

// Define a unified `enum class : i32` for all element-wise op functions.
def ElementwiseFn :
I32EnumAttr<"ElementwiseFn",
"",
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
BinaryFn.enumerants,
TernaryFn.enumerants]>.result
> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Define an `enum class : i32` that marks where each individual enum class
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseFn.
def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> {
int last_unary = !size(UnaryFn.enumerants);
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));

let enumerants = [
I32EnumAttrCase<"LastUnary", last_unary>,
I32EnumAttrCase<"LastBinary", last_binary>,
I32EnumAttrCase<"LastTernary", last_ternary>];
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Define an `enum class : i32` to categorise elementwise ops.
def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this just be ElementwiseArity? The concept of NAryCategory is exactly arity, right?

I32EnumAttrCase<"Unary", 0>,
I32EnumAttrCase<"Binary", 1>,
I32EnumAttrCase<"Ternary", 2>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
Expand Down
116 changes: 116 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,122 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElementwiseOp
//===----------------------------------------------------------------------===//
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
AttrSizedOperandSegments]> {
let summary = [{ Performs element-wise operation }];
let description = [{
Linalg op form which performs element-wise computation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] This is repeating info from the summary, deleteme.


The attribute `kind` describes the operation (e.g. add, exp). The operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit]

Suggested change
The attribute `kind` describes the operation (e.g. add, exp). The operation
The attribute `kind` describes the arithmetic operation to perform. This operation
can either be unary (e.g. max), binary (e.g. add) or ternary (i.e. select).

kind can be any elementwise nary (e.g. unary, binary) operation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rephrase to use arity instead of nary


Affine-maps for operands and result are required to be provided by the user
when transpose and/or broadcast is needed on any operand. When a map is not
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: when a transpose ...

Comment on lines +566 to +567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Could you start with the default behaviour (e.g. "By default, all indexing maps are identities.")?
  2. Is it OK to specify only one (or two) of the maps? Or is it "either all or nothing"? Please clarify.

provided, default identity maps are inferred for each operand. The number
of dims in each of the identity maps is equal to the rank of the output type.
In the case of default indexing map, all input and output shapes must match.
User-defined affine-map for operands and result must only be projected
permutations with no zero constants.
Comment on lines +571 to +572
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason for this restriction? This restriction does not seem relevant here. The op is still elementwise, no matter how you define the the iteration map.

Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question! Are there any considerations arising from the "linalg tree" for why the indexing_maps should be projected permutations?


For elementwise, iterator-types are always `all parallel`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: all parallel

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, skip For elementwise. This is duplicating info.

Iterator-types are needed for constructing the underlying structured op.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is relevant here. This is somehow implying the op stores iterator types on the operation. This is confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on this.

The number of dims of the iterator-types are inferred from the rank of
the result type.

Example:

Defining a unary linalg.elemwise with default indexing-map:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old op's name, needs updating

```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_fn<exp>
ins(%x : tensor<4x16x8xf32>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```

Defining a binary linalg.elemwise with user-defined indexing-map:
```mlir
%add = linalg.elemwise
kind=#linalg.elemwise_fn<add>
indexing_maps = [#transpose, #broadcast, #identity]
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
```
}];

let arguments = (ins
Variadic<AnyShaped>:$inputs,
Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow broadcasting a scalar (e.g. as is supported by generic, contract and matmul)? If so, this AnyShaped should probably be AnyType.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 we should allow that

Variadic<AnyShaped>:$outputs,
ElementwiseFnAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this should be DefaultValuedOptionalAttr. "{}" can be invalid in many cases. Instead, we should just have a builder for having a derived default value of the attribute.

);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let skipDefaultBuilders = 1;

let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we inline buildElementwiseOp here? As buildElementwiseOp is just forwarding args to buildStructuredOp, it seems the separate function (not close to any other elementwise op code) is unnecessary.

attributes, ElementwiseOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the nary category enum, e.g. `ElementwiseNAryCategory::Unary`,
/// corresponding to the given fn, e.g. `ElementwiseFn::exp`
static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ElementwiseNAryCategory just encodes the arity, do we need a separate enum for this? Would just an unsigned int suffice? (The thing we would lose is that the enum encodes/enforce arities are restricted to unary, binary and ternary - maybe that's sufficient reason to keep it.)


/// Both user-specified and default indexing map will always depend on
/// the current Op instance.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: double space

static bool hasDynamicIndexingMaps() { return true; }

/// Implements the block region builder for the elementwiseOp. This is
/// called by the 'fillStructuredOpRegion'.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}

/// Returns rank of the result tensor/memref. Useful for knowing
/// the dimensionality of the iteration space when others means
/// are not possible e.g. absence of user-provided indexing map.
unsigned getResultRank();

/// Returns N 'parallel' iterator types where N is rank of result.
SmallVector<utils::IteratorType> getIteratorTypesArray();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the availability of getResultRank(), sounds like the definition could be here in the .td file.


/// The default indexing maps are identities.
/// There will be N such maps, where N is the arity of the Op.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should "N" maybe be "arity"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and numDims could be rank, right?

static SmallVector<AffineMap>
getDefaultIndexingMaps(unsigned N, unsigned numDims,
MLIRContext *context);

/// Destination passing style interface method.
::mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}

// Generic methods.
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
}
}];
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
Loading