[MLIR][Linalg] Introduce linalg.contract #123618

Merged · 11 commits · Jan 29, 2025
136 changes: 136 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,142 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}

//===----------------------------------------------------------------------===//
// Contract op.
//===----------------------------------------------------------------------===//

def ContractOp : LinalgStructuredBase_Op<"contract", [
AttrSizedOperandSegments,
LinalgContractionOpInterface]> {
let summary = [{
Perform a contraction on two inputs, accumulating into the third.
}];
let description = [{
The semantics of contracting inputs `A` and `B` on top of `C` to produce
output `D` is given by

`D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`

where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
identifiers - meant to range over valid indices - corresponding to the
results of the mandatory (projected permutation) `indexing_maps` for `A`,
`B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
dim identifiers).
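To make these semantics concrete, here is a pure-Python sketch (illustrative
dim sizes and values, not part of the op definition) of the batch-matmul
instance, where `I = (b, m, k)`, `J = (b, k, n)`, `H = (b, m, n)` and the
reduced set `(I ∪ J) \ H` is `{k}`:

```python
# Pure-Python sketch of D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]
# for batch-matmul: I = (b, m, k), J = (b, k, n), H = (b, m, n).
BATCH, M, N, K = 2, 3, 4, 5  # hypothetical dim sizes

# Illustrative operand contents: A[b][m][k], Bmat[b][k][n], C[b][m][n].
A = [[[float(b * M * K + m * K + k) for k in range(K)]
      for m in range(M)] for b in range(BATCH)]
Bmat = [[[float(b * K * N + k * N + n) for n in range(N)]
         for k in range(K)] for b in range(BATCH)]
C = [[[1.0 for n in range(N)] for m in range(M)] for b in range(BATCH)]

# D[b][m][n] = (SUM_k A[b][m][k] * Bmat[b][k][n]) + C[b][m][n]
D = [[[C[b][m][n] + sum(A[b][m][k] * Bmat[b][k][n] for k in range(K))
       for n in range(N)] for m in range(M)] for b in range(BATCH)]
```

Since `k` is the only dim absent from `H`, it is the only one summed over;
`b`, `m`, and `n` index the output directly.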

The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
domain of each of the `affine_map`s. Like for einsums, the iteration type of
each dim is inferred and is either:

- reduction: the dim is used to index into `A` and `B` but not `C`. Per the
above semantics, these dims will be contracted, i.e. reduced over.

- parallel: the dim is used to index into `C` and at least one of `A` and
`B`, and - deriving from matmul terminology - is either an "M-like" dim
(if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
"batch"-dim (if used to index into `A`, `B`, and `C`).
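The inference rule above can be sketched in a few lines of Python (a
hypothetical helper mirroring the rule; the actual implementation is the
C++ `getIteratorTypesArray` declared further down):

```python
# Hypothetical sketch of iterator-type inference from the result dims of
# the three indexing maps: a dim is "reduction" iff it does not appear in
# C's map, and "parallel" otherwise.
def infer_iterator_types(dims_a, dims_b, dims_c):
    kinds = {}
    for d in set(dims_a) | set(dims_b) | set(dims_c):
        kinds[d] = "parallel" if d in dims_c else "reduction"
    return kinds

# batch-matmul: I = (b, m, k), J = (b, k, n), H = (b, m, n)
kinds = infer_iterator_types(("b", "m", "k"), ("b", "k", "n"), ("b", "m", "n"))
# k is contracted (reduction); b, m, n are parallel.
```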

For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
`H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
`n` and `b` have parallel iteration-type) and gets represented as:

```
%D = linalg.contract
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
```

Note that by permuting dims in the `affine_map`s' results, accesses to
the inputs and output can be arbitrarily transposed. Similarly, arbitrary
broadcasts can be achieved through leaving out dims on either input operand.
For example, the following is a variant of batch-matmul with a transposition
applied to `A` while `B`'s 2D-matrix gets broadcast along the batch dim:

```
linalg.contract
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
affine_map<(batch, m, n, k) -> (k, n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?x?xf32>)
```
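A pure-Python sketch of this variant (assumed small sizes, illustrative
values only) shows how the transposed access into `A` and the broadcast of
the 2D `B` fall out of the indexing maps:

```python
# Sketch of the variant above: A is read transposed as (batch, k, m),
# and the 2-D B is reused for every batch because its map omits the
# batch dim (a broadcast).
BATCH, M, N, K = 2, 2, 3, 4
A = [[[float(b + k + m) for m in range(M)]
      for k in range(K)] for b in range(BATCH)]   # accessed as (batch, k, m)
B = [[float(k * N + n) for n in range(N)] for k in range(K)]  # (k, n)
C = [[[0.0 for n in range(N)] for m in range(M)] for b in range(BATCH)]

for b in range(BATCH):
    for m in range(M):
        for n in range(N):
            # Note A[b][k][m], not A[b][m][k]: the transposed access.
            C[b][m][n] += sum(A[b][k][m] * B[k][n] for k in range(K))
```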

Numeric casting is performed on the operands to the inner multiplication,
promoting/truncating them to the same data type as the accumulator/output.
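The casting rule can be sketched as follows (an illustrative Python model
with `float` standing in for the accumulator's element type; not the actual
lowering):

```python
# Illustrative sketch of the casting rule: each operand of the inner
# multiplication is first converted to the accumulator/output element
# type (modeled here as float), so the multiply-accumulate runs at
# output precision.
def contract_1d(a_vec, b_vec, c, acc_type=float):
    acc = acc_type(c)
    for a, b in zip(a_vec, b_vec):
        acc += acc_type(a) * acc_type(b)  # promote, then multiply
    return acc

d = contract_1d([1, 2], [3, 4], 0)
# d == 11.0, a float even though the inputs were ints
```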

TODO: Allow control over the combining/accumulating op and possibly the
multiplication op.
}];

let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it gets populated at op
// build time - is that currently the LinalgOp interface exposes methods
// that assume a relevant region is available to be queried at any time.
let regions = (region SizedRegion<1>:$combiner);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayAttr":$indexingMaps,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addAttribute("indexing_maps", indexingMaps);
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
outputs, attributes, regionBuilder);
}]>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayAttr":$indexingMaps,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addAttribute("indexing_maps", indexingMaps);
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, regionBuilder);
}]>
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare/implement functions necessary for LinalgStructuredInterface.

/// Infer iterator types for each dim in the domain of IndexingMaps.
SmallVector<utils::IteratorType> getIteratorTypesArray();

/// The indexing maps always depend on the attr attached to the current
/// op instance.
bool hasDynamicIndexingMaps() { return true; }
bool hasUserDefinedMaps() { return true; }

static unsigned getNumRegionArgs();

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}

std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement function necessary for DestinationStyleOpInterface.
::mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}

//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//