diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d9..239dc9aa1de6f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,40 +11,43 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+                   int step = 1) {
+  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+    assert(indices[dim] < vecType.getDimSize(dim) &&
+           "Indices are out of bound");
+    indices[dim] += step;
+    if (indices[dim] < vecType.getDimSize(dim))
+      break;
+
+    indices[dim] = 0;
+    step = 1;
+  }
+}
+
 namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
@@ -53,35 +56,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
+    int64_t extractSize = sourceVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+    // Compute the indices of each 1-D vector element of the source extraction
+    // and destination slice insertion and generate such instructions.
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/1);
+        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      }
+
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extract, result,
+          /*offsets=*/resIdx, /*strides=*/1);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into n-D
 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
@@ -90,43 +110,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank - 1, 0);
+    int64_t extractSize = resultVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType, /*step=*/1);
+      }
+
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
           /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
-                   int dimIdx, int initialStep = 1) {
-  int step = initialStep;
-  for (int d = dimIdx; d >= 0; d--) {
-    idx[d] += step;
-    if (idx[d] >= tp.getDimSize(d)) {
-      idx[d] = 0;
-      step = 1;
-    } else {
-      break;
-    }
-  }
-}
-
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -145,18 +165,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    // Special case 2D / 1D lowerings with better implementations.
-    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+    // Special case for n-D / 1-D lowerings with better implementations.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
       return failure();
 
     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
     // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
       numElts *= sourceVectorType.getDimSize(r);
@@ -166,14 +182,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     //   x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
+        incIdx(srcIdx, sourceVectorType);
+        incIdx(resIdx, resultVectorType);
       }
 
       Value extract;
@@ -252,7 +268,7 @@ class ScalableShapeCastOpRewritePattern
     // have a single trailing scalable dimension. This is because there are no
     // legal representation of other scalable types in LLVM (and likely won't be
     // soon). There are also (currently) no operations that can index or extract
-    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
         !isTrailingDimScalable(resultVectorType)) {
       return failure();
@@ -278,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
 
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
 
     // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
     // once D150000 lands.
@@ -334,8 +350,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
-      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
@@ -352,8 +368,8 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+  patterns.add<ShapeCastOpNDDownCastRewritePattern,
+               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                   benefit);
 }
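Before the test updates below, a quick illustration of the row-major index stepping that the new incIdx helper performs. This is a self-contained sketch, not part of the patch: it uses a plain shape vector in place of an MLIR VectorType, and the main driver replays the n-D -> 1-D down-cast bookkeeping for vector<1x3x2xf32> (numElts = 1 * 3 = 3 innermost 1-D vectors of length extractSize = 2).

// Illustration only: a plain-vector analogue of the patch's incIdx().
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

static void incIdx(std::vector<int64_t> &indices,
                   const std::vector<int64_t> &shape, int step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    assert(indices[dim] < shape[dim] && "Indices are out of bound");
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;            // No carry needed; stop at this dimension.
    indices[dim] = 0;   // Wrap around and carry into the next outer dim.
    step = 1;
  }
}

int main() {
  std::vector<int64_t> srcShape = {1, 3}; // leading dims of vector<1x3x2xf32>
  std::vector<int64_t> srcIdx = {0, 0};   // n-1 leading extract indices
  std::vector<int64_t> resIdx = {0};      // offset into the 1-D vector<6xf32>
  const int64_t extractSize = 2;          // innermost dim of the source
  for (int64_t i = 0; i < 3; ++i) {       // numElts = 1 * 3
    if (i != 0) {
      incIdx(srcIdx, srcShape, /*step=*/1);
      incIdx(resIdx, {6}, /*step=*/extractSize);
    }
    std::printf("extract [%lld, %lld] -> insert offset %lld\n",
                (long long)srcIdx[0], (long long)srcIdx[1],
                (long long)resIdx[0]);
  }
  return 0;
}

This prints the pairs [0,0]->0, [0,1]->2, [0,2]->4, which are exactly the vector.extract positions and insert_strided_slice offsets the @shape_cast_3d1d test below expects. Note that the loop increments only from the second iteration on, so the indices never step past the shape and the assert (the same invariant the patch asserts) always holds.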
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index f2f1211fd70ee..b4c52d5533116 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
@@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
-// CHECK: return %[[T11]] : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
 
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: return %[[T11]] : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
 
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
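To connect these CHECK lines back to the implementation, here is the mirror-image sketch for the 1-D -> n-D up-cast (again illustrative, not part of the patch): the offset into the 1-D source steps by extractSize = 3 while the n-1 leading result indices step row-major by 1, reproducing the {offsets = [0]} / [0, 0] and {offsets = [3]} / [1, 0] pairs checked in @shape_cast_1d3d above.

// Illustration only: replays the up-cast bookkeeping for
// vector<6xf32> -> vector<2x1x3xf32> with plain shape vectors.
#include <cstdint>
#include <cstdio>
#include <vector>

static void incIdx(std::vector<int64_t> &indices,
                   const std::vector<int64_t> &shape, int step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;            // No carry; done.
    indices[dim] = 0;   // Carry row-major into the next outer dimension.
    step = 1;
  }
}

int main() {
  const std::vector<int64_t> resShape = {2, 1, 3}; // vector<2x1x3xf32>
  const int64_t extractSize = resShape.back();     // 3
  int64_t numElts = 1;                             // product of leading dims
  for (size_t dim = 0; dim + 1 < resShape.size(); ++dim)
    numElts *= resShape[dim];                      // 2 * 1 = 2
  std::vector<int64_t> srcIdx = {0};               // offset into vector<6xf32>
  std::vector<int64_t> resIdx = {0, 0};            // n-1 leading result indices
  for (int64_t i = 0; i < numElts; ++i) {
    if (i != 0) {
      incIdx(srcIdx, {6}, /*step=*/extractSize);
      incIdx(resIdx, {resShape[0], resShape[1]}, /*step=*/1);
    }
    std::printf("extract_strided_slice offsets = [%lld] -> insert at [%lld, %lld]\n",
                (long long)srcIdx[0], (long long)resIdx[0],
                (long long)resIdx[1]);
  }
  return 0;
}

The middle unit dimension of vector<2x1x3xf32> is where the carry logic matters: incrementing resIdx [0, 0] overflows dimension 1 (size 1) immediately and carries into dimension 0, which is why the second slice is inserted at [1, 0] rather than [0, 1].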