Further doc updates per discussion with @banach-space
rolfmorel committed Jan 29, 2025
1 parent 1dc42f0 commit e8c3c1b
Showing 3 changed files with 39 additions and 36 deletions.
41 changes: 22 additions & 19 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -696,27 +696,28 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [

`D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`

where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
dimension identifiers (meant to range over valid indices), corresponding to
the co-domains of the mandatory (projected permutation) `indexing_maps` of
`A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
indices for the dimensions in the set `dims`.
where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
identifiers - meant to range over valid indices - corresponding to the
results of the mandatory (projected permutation) `indexing_maps` for `A`,
`B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
dim identifiers).

The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
domain of each of the `affine_map`s. Like for einsums, the iteration type of
each dim is inferred and is either:

- reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
Per the above semantics, these dims will be contracted, i.e. reduced over.
- reduction: the dim is used to index into `A` and `B` but not `C`. Per the
above semantics, these dims will be contracted, i.e. reduced over.

- parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
deriving from matmul terminology - is either an "M-like" dim (if in `A`
and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
`B`, and `C`).
- parallel: the dim is used to index into `C` and at least one of `A` and
`B`, and - deriving from matmul terminology - is either an "M-like" dim
(if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
"batch"-dim (if used to index into `A`, `B`, and `C`).

For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
`H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
`n` and `b` are of parallel iteration-type) and gets represented as:
`n` and `b` have parallel iteration-type) and gets represented as:

```
%D = linalg.contract
@@ -727,12 +728,11 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
```
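
For comparison, a plain (non-batched) matmul drops the `b` dim, leaving
`I = ⟨ m, k ⟩`, `J = ⟨ k, n ⟩` and `H = ⟨ m, n ⟩`. A minimal sketch (the
dynamic `?x?` shapes and `f32` element type are illustrative choices, not
requirements):

```
%D = linalg.contract
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>]
    ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
```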

Note that by permuting dims in the co-domains of the `affine_map`s arbitrary
transposes can be applied to the inputs and output. Similarly, arbitrary
broadcasts can be achieved through leaving out dims on either input operand
(these dims' inferred iter type will be parallel). For example, the
following is a variant of batch-matmul where a transposition is applied to
`A` while matrix `B` gets broadcasted along the batch dimension:
Note that by permuting dims in the `affine_map`s' results, accesses to
the inputs and output can be arbitrarily transposed. Similarly, arbitrary
broadcasts can be achieved through leaving out dims on either input operand.
For example, the following is a variant of batch-matmul with a transposition
applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:

```
linalg.contract
@@ -744,7 +744,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
```
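
One plausible spelling of such maps is sketched below (the dim names and
shapes are assumptions for illustration, not the exact lines elided above).
Here `A` is accessed as `(batch, k, m)`, i.e. transposed, while `B`'s map
drops the batch dim so its 2D matrix is broadcast:

```
%D = linalg.contract
    indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                     affine_map<(batch, m, n, k) -> (k, n)>,
                     affine_map<(batch, m, n, k) -> (batch, m, n)>]
    ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?xf32>)
    outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
```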

Numeric casting is performed on the operands to the inner multiplication,
promoting them to the same data type as the accumulator/output.
promoting/truncating them to the same data type as the accumulator/output.
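
For instance, assuming `linalg.contract` accepts mixed element types the way
other linalg named ops do, `f16` inputs would be promoted to `f32` to match an
`f32` accumulator (a minimal sketch under that assumption):

```
%D = linalg.contract
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>]
    ins(%A, %B: tensor<?x?xf16>, tensor<?x?xf16>)
    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
```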

TODO: Allow control over the combining/accumulating op and possibly the
multiplication op.
@@ -756,6 +756,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
AffineMapArrayAttr:$indexing_maps
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it gets populated at op build
// time - is that currently the LinalgOp interface exposes methods that
// assume a relevant region is available to be queried at any time.
let regions = (region SizedRegion<1>:$combiner);

let skipDefaultBuilders = 1;
32 changes: 16 additions & 16 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3697,33 +3697,33 @@ LogicalResult ContractOp::verify() {
SmallVector<size_t> inOccurrences;
SmallVector<size_t> outOccurrences;

// For each operand's affine_map and type, check that the rank of the
// affine_map's domain is the same as those seen prior, check that the
// affine_map's co-domain rank is the same as that of the corresponding type,
// check that the affine_map is a projected permutation, and, finally, update
// inputs and output occurrence counts for dims in the co-domains.
// A helper so that for each operand's affine_map and type we check that ...
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
bool isInput) -> LogicalResult {
if (iterationSpaceDims == -1) {
iterationSpaceDims = affineMap.getNumDims();
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
return emitError("iteration spaces of provided affine_maps differ");
}
// ... the affine_map is a projected permutation;
if (!affineMap.isProjectedPermutation())
return emitError("provided affine_map is not a projected permutation");

// ... the rank of the affine_map's results and corresponding type match;
if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
if (affineMap.getNumResults() != shapedType.getRank())
return emitError("ranks of shaped operand and co-domain of "
"corresponding affine_map differ");
return emitError("ranks of shaped operand and results of corresponding"
"affine_map differ");
} else if (affineMap.getNumResults() != 0) {
return emitError("affine_map specifies shaped access while operand has "
"non-shaped type");
}

if (!affineMap.isProjectedPermutation())
return emitError("provided affine_map is not a projected permutation");
// ... the rank of the affine_map's domain is the same as those seen prior;
if (iterationSpaceDims == -1) {
iterationSpaceDims = affineMap.getNumDims();
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
return emitError("iteration spaces of provided affine_maps differ");
}

// ... update counts of dims used to access either an input or the output.
for (AffineExpr affineExpr : affineMap.getResults()) {
auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
if (!affineDimExpr)
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/invalid.mlir
@@ -574,7 +574,7 @@ func.func @differing_iteration_space_of_affine_maps_contraction(

func.func @mismatched_ranks_affine_map_and_operand_contraction(
%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
// expected-error @+1 {{ranks of shaped operand and results of corresponding affine_map differ}}
linalg.contract
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
