-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Extend elementwise #124661
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [ | |
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
// Join two I32EnumAttrCase lists. This joining takes care that the | ||
// 'int enum values' in the combined list do not overlap. It does this | ||
// by adding to each element of second list the offset '!size(a)'. | ||
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking out loud, perhaps it would be easier to use bit patterns instead of joining enum lists. We won't have more than 20 operations per
And set the enums above like:
etc? Then you don't need all the complex sequential logic in the parser/verifiers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need anymore. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the alternative join-based approach, @javedabsar1 - It's quite impressive what can done with TableGen! (For better or worse, TableGen is a programming language on its own.) I would still think an approach like @rengolin's would lead to simpler C++. The scheme I have in mind is to just shift the arity 30 bits, e.g. Three nested What do you think? (If you could just state the benefits of your approach, that would also be fine OFC.) |
||
list<I32EnumAttrCase> b> { | ||
int aSize = !size(a); | ||
list<I32EnumAttrCase> result = | ||
!foldl(a, b, acc, var, | ||
acc # [I32EnumAttrCase<var.symbol, | ||
!add(var.value, aSize) | ||
>]); | ||
} | ||
|
||
// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'. | ||
// The flattening (via call to 'join') ensures no overlap in enum values. | ||
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> { | ||
list<I32EnumAttrCase> result = | ||
!foldl([]<I32EnumAttrCase>, l, acc, var, | ||
JoinTwoI32EnumAttrCaseList<acc, var>.result); | ||
} | ||
|
||
// Define a unified `enum class : i32` for all element-wise op functions. | ||
def ElementwiseFn : | ||
I32EnumAttr<"ElementwiseFn", | ||
"", | ||
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants, | ||
BinaryFn.enumerants, | ||
TernaryFn.enumerants]>.result | ||
> { | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
// Define an `enum class : i32` that marks where each individual enum class | ||
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseFn. | ||
def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> { | ||
int last_unary = !size(UnaryFn.enumerants); | ||
int last_binary = !add(last_unary, !size(BinaryFn.enumerants)); | ||
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants)); | ||
|
||
let enumerants = [ | ||
I32EnumAttrCase<"LastUnary", last_unary>, | ||
I32EnumAttrCase<"LastBinary", last_binary>, | ||
I32EnumAttrCase<"LastTernary", last_ternary>]; | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
// Define an `enum class : i32` to categorise elementwise ops. | ||
def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't this just be |
||
I32EnumAttrCase<"Unary", 0>, | ||
I32EnumAttrCase<"Binary", 1>, | ||
I32EnumAttrCase<"Ternary", 2> | ||
]> { | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
def TypeFn : I32EnumAttr<"TypeFn", "", [ | ||
I32EnumAttrCase<"cast_signed", 0>, | ||
I32EnumAttrCase<"cast_unsigned", 1> | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -551,6 +551,122 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ | |||||||
let hasCanonicalizer = 1; | ||||||||
} | ||||||||
|
||||||||
//===----------------------------------------------------------------------===// | ||||||||
// Op definition for ElementwiseOp | ||||||||
//===----------------------------------------------------------------------===// | ||||||||
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [ | ||||||||
AttrSizedOperandSegments]> { | ||||||||
let summary = [{ Performs element-wise operation }]; | ||||||||
let description = [{ | ||||||||
Linalg op form which performs element-wise computation. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit] This is repeating info from the summary, deleteme. |
||||||||
|
||||||||
The attribute `kind` describes the operation (e.g. add, exp). The operation | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit]
Suggested change
|
||||||||
kind can be any elementwise nary (e.g. unary, binary) operation. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: rephrase to use arity instead of nary |
||||||||
|
||||||||
Affine-maps for operands and result are required to be provided by the user | ||||||||
when transpose and/or broadcast is needed on any operand. When a map is not | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: when a transpose ...
Comment on lines
+566
to
+567
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||
provided, default identity maps are inferred for each operand. The number | ||||||||
of dims in each of the identity maps is equal to the rank of the output type. | ||||||||
In the case of default indexing map, all input and output shapes must match. | ||||||||
User-defined affine-map for operands and result must only be projected | ||||||||
permutations with no zero constants. | ||||||||
Comment on lines
+571
to
+572
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any reason for this restriction? This restriction does not seem relevant here. The op is still elementwise, no matter how you define the the iteration map. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good question! Are there any considerations arising from the "linalg tree" for why the |
||||||||
|
||||||||
For elementwise, iterator-types are always `all parallel`. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: all There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, skip |
||||||||
Iterator-types are needed for constructing the underlying structured op. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why this is relevant here. This is somehow implying the op stores iterator types on the operation. This is confusing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed on this. |
||||||||
The number of dims of the iterator-types are inferred from the rank of | ||||||||
the result type. | ||||||||
|
||||||||
Example: | ||||||||
|
||||||||
Defining a unary linalg.elemwise with default indexing-map: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Old op's name, needs updating |
||||||||
```mlir | ||||||||
%exp = linalg.elemwise | ||||||||
kind=#linalg.elemwise_fn<exp> | ||||||||
ins(%x : tensor<4x16x8xf32>) | ||||||||
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32> | ||||||||
``` | ||||||||
|
||||||||
Defining a binary linalg.elemwise with user-defined indexing-map: | ||||||||
```mlir | ||||||||
%add = linalg.elemwise | ||||||||
kind=#linalg.elemwise_fn<add> | ||||||||
indexing_maps = [#transpose, #broadcast, #identity] | ||||||||
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) | ||||||||
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> | ||||||||
``` | ||||||||
}]; | ||||||||
|
||||||||
let arguments = (ins | ||||||||
Variadic<AnyShaped>:$inputs, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to allow broadcasting a scalar (e.g. as is supported by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 we should allow that |
||||||||
Variadic<AnyShaped>:$outputs, | ||||||||
ElementwiseFnAttr:$kind, | ||||||||
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this should be DefaultValuedOptionalAttr. "{}" can be invalid in many cases. Instead, we should just have a builder for having a derived default value of the attribute. |
||||||||
); | ||||||||
|
||||||||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors); | ||||||||
let regions = (region AnyRegion:$region); | ||||||||
let skipDefaultBuilders = 1; | ||||||||
|
||||||||
let builders = [ | ||||||||
OpBuilder< | ||||||||
(ins "ValueRange":$inputs, "ValueRange":$outputs, | ||||||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), | ||||||||
[{ | ||||||||
buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we inline |
||||||||
attributes, ElementwiseOp::getRegionBuilder()); | ||||||||
}]> | ||||||||
]; | ||||||||
|
||||||||
let hasCustomAssemblyFormat = 1; | ||||||||
let hasFolder = 1; | ||||||||
let hasVerifier = 1; | ||||||||
|
||||||||
let extraClassDeclaration = structuredOpsBaseDecls # [{ | ||||||||
/// Get the nary category enum, e.g. `ElementwiseNAryCategory::Unary`, | ||||||||
/// corresponding to the given fn, e.g. `ElementwiseFn::exp` | ||||||||
static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||||||||
|
||||||||
/// Both user-specified and default indexing map will always depend on | ||||||||
/// the current Op instance. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: double space |
||||||||
static bool hasDynamicIndexingMaps() { return true; } | ||||||||
|
||||||||
/// Implements the block region builder for the elementwiseOp. This is | ||||||||
/// called by the 'fillStructuredOpRegion'. | ||||||||
static void regionBuilder(ImplicitLocOpBuilder &b, | ||||||||
Block &block, ArrayRef<NamedAttribute> attrs); | ||||||||
|
||||||||
static std::function<void(ImplicitLocOpBuilder &, | ||||||||
Block &, ArrayRef<NamedAttribute>)> | ||||||||
getRegionBuilder() { | ||||||||
return regionBuilder; | ||||||||
} | ||||||||
|
||||||||
/// Returns rank of the result tensor/memref. Useful for knowing | ||||||||
/// the dimensionality of the iteration space when others means | ||||||||
/// are not possible e.g. absence of user-provided indexing map. | ||||||||
unsigned getResultRank(); | ||||||||
|
||||||||
/// Returns N 'parallel' iterator types where N is rank of result. | ||||||||
SmallVector<utils::IteratorType> getIteratorTypesArray(); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the availability of |
||||||||
|
||||||||
/// The default indexing maps are identities. | ||||||||
/// There will be N such maps, where N is the arity of the Op. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: should "N" maybe be " There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and |
||||||||
static SmallVector<AffineMap> | ||||||||
getDefaultIndexingMaps(unsigned N, unsigned numDims, | ||||||||
MLIRContext *context); | ||||||||
|
||||||||
/// Destination passing style interface method. | ||||||||
::mlir::MutableOperandRange getDpsInitsMutable() { | ||||||||
return getOutputsMutable(); | ||||||||
} | ||||||||
|
||||||||
// Generic methods. | ||||||||
std::string getLibraryCallName() { | ||||||||
return generateLibraryCallName(getOperation()); | ||||||||
} | ||||||||
}]; | ||||||||
} | ||||||||
|
||||||||
//===----------------------------------------------------------------------===// | ||||||||
// Op definition for MatmulOp | ||||||||
//===----------------------------------------------------------------------===// | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that we would be switching to
kind
instead ofFn
/function
etc?#122753 (comment)
Also, you already seem to be using
kind
in various places.