Use DotOp linalg lowerings for simple DotGeneral ops (openxla#2314)
We are in the process of deprecating DotOp (openxla#2296).

To ease the deprecation, we want DotGeneral ops that are equivalent to a
simple DotOp to be lowered to linalg in the same way.

In this PR:
- Add an "isSimpleDot" method to DotGeneralOp that determines whether a
DotGeneralOp can be expressed as a DotOp.
- Reuse the DotOp linalg lowerings for DotGeneralOp when isSimpleDot is true.
- Remove "DotOperationType", as it is redundant with the linalg op types.

openxla#2311
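
For context, an illustrative sketch of what "simple" means here (this example is not part of the diff; the operand names and shapes are made up): per the new isSimpleDot, a DotGeneralOp is simple when its dot dimension numbers equal the DotOp defaults, i.e. no batching dimensions and contracting dimensions [rank(lhs) - 1] x [0]. For 2-D operands that is an ordinary matrix multiply:

    // Equivalent to "stablehlo.dot %lhs, %rhs" on the same operands.
    %0 = stablehlo.dot_general %lhs, %rhs, contracting_dims = [1] x [0]
        : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>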
mlevesquedion authored May 9, 2024
1 parent c31ba3d commit 9d4f27d
Showing 6 changed files with 180 additions and 80 deletions.
74 changes: 74 additions & 0 deletions stablehlo/conversions/linalg/tests/dot-product.mlir
@@ -273,3 +273,77 @@ func.func @dot_dot_unsigned(%arg0: tensor<?xui32>,
// CHECK: linalg.dot
// CHECK-SAME: ins(%{{.*}} : tensor<?xi32>, tensor<?xi32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<i32>)

// -----

func.func @simple_dot_general_matmul(
%arg0: tensor<2x3xf32>,
%arg1: tensor<3x?xf32>
) -> tensor<2x?xf32> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] {someattr}
: (tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32>
func.return %0 : tensor<2x?xf32>
}

// CHECK-LABEL: func @simple_dot_general_matmul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]])
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.matmul
// CHECK-SAME: {someattr}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)

// -----

func.func @simple_dot_general_matvec(%arg0: tensor<?x3xf32>,
%arg1: tensor<3xf32>) -> tensor<?xf32> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0]
: (tensor<?x3xf32>, tensor<3xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @simple_dot_general_matvec(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.*]] = tensor.empty(%[[D0]])
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.matvec
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)

// -----

func.func @simple_dot_general_vecmat(%arg0: tensor<3xf32>,
%arg1: tensor<3x?xf32>) -> tensor<?xf32> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0] x [0]
: (tensor<3xf32>, tensor<3x?xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @simple_dot_general_vecmat(
// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]])
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.vecmat
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)

// -----

func.func @simple_dot_general_dot(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>) -> tensor<f32> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0] x [0]
: (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
func.return %0 : tensor<f32>
}
// CHECK-LABEL: func @simple_dot_general_dot(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
// CHECK: %[[INIT:.*]] = tensor.empty()
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.dot
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<f32>)
@@ -24,95 +24,91 @@ limitations under the License.

namespace mlir::stablehlo {
namespace {
-enum class DotOperationType {
-  kVectorDot = 0,
-  kMatrixVector,
-  kVectorMatrix,
-  kMatrixMatrix,
-  kUnsupported
-};
-
-DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) {
-  ArrayRef<int64_t> lhsShape = dotOp.getLhs().getType().getShape();
-  ArrayRef<int64_t> rhsShape = dotOp.getRhs().getType().getShape();
-  auto shapeMatches = [](int64_t a, int64_t b) {
+template <typename LinalgOpTy, typename StablehloOpTy>
+bool opMatchesLinalgTarget(StablehloOpTy op) {
+  ArrayRef<int64_t> lhsShape = op.getLhs().getType().getShape();
+  ArrayRef<int64_t> rhsShape = op.getRhs().getType().getShape();
+  auto areCompatible = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamic || b == ShapedType::kDynamic || a == b;
  };
  if (lhsShape.size() == 1 && rhsShape.size() == 1 &&
-      shapeMatches(lhsShape[0], rhsShape[0])) {
-    return DotOperationType::kVectorDot;
+      areCompatible(lhsShape[0], rhsShape[0])) {
+    return std::is_same<LinalgOpTy, linalg::DotOp>::value;
  }
  if (lhsShape.size() == 2 && rhsShape.size() == 1 &&
-      shapeMatches(lhsShape[1], rhsShape[0])) {
-    return DotOperationType::kMatrixVector;
+      areCompatible(lhsShape[1], rhsShape[0])) {
+    return std::is_same<LinalgOpTy, linalg::MatvecOp>::value;
  }
  if (lhsShape.size() == 1 && rhsShape.size() == 2 &&
-      shapeMatches(lhsShape[0], rhsShape[0])) {
-    return DotOperationType::kVectorMatrix;
+      areCompatible(lhsShape[0], rhsShape[0])) {
+    return std::is_same<LinalgOpTy, linalg::VecmatOp>::value;
  }
  if (lhsShape.size() == 2 && rhsShape.size() == 2 &&
-      shapeMatches(lhsShape[1], rhsShape[0])) {
-    return DotOperationType::kMatrixMatrix;
+      areCompatible(lhsShape[1], rhsShape[0])) {
+    return std::is_same<LinalgOpTy, linalg::MatmulOp>::value;
  }
-  return DotOperationType::kUnsupported;
+  return false;
}

+template <typename LinalgOp>
SmallVector<Value, 2> getDotOpEmptyTensorDynSizes(OpBuilder &b, Location loc,
-                                                  Value lhs, Value rhs,
-                                                  DotOperationType type) {
+                                                  Value lhs, Value rhs) {
  SmallVector<Value, 2> dynShape;
-  switch (type) {
-    case DotOperationType::kMatrixMatrix: {
-      if (llvm::cast<ShapedType>(lhs.getType()).isDynamicDim(0))
-        dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
-      if (llvm::cast<ShapedType>(rhs.getType()).isDynamicDim(1))
-        dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
-      break;
-    }
-    case DotOperationType::kMatrixVector: {
-      if (llvm::cast<ShapedType>(lhs.getType()).isDynamicDim(0))
-        dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
-      break;
-    }
-    case DotOperationType::kVectorMatrix: {
-      if (llvm::cast<ShapedType>(rhs.getType()).isDynamicDim(1))
-        dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
-      break;
-    }
-    case DotOperationType::kVectorDot:
-    case DotOperationType::kUnsupported:
-      break;
-  }
+
+  auto lhsType = cast<ShapedType>(lhs.getType());
+  auto rhsType = cast<ShapedType>(rhs.getType());
+
+  auto lhsIsMatrix = std::is_same<LinalgOp, linalg::MatvecOp>::value;
+  auto rhsIsMatrix = std::is_same<LinalgOp, linalg::VecmatOp>::value;
+  if (std::is_same<LinalgOp, linalg::MatmulOp>::value) {
+    lhsIsMatrix = rhsIsMatrix = true;
+  }
+
+  if (lhsIsMatrix && lhsType.isDynamicDim(0))
+    dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
+  if (rhsIsMatrix && rhsType.isDynamicDim(1))
+    dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
  return dynShape;
}

-template <DotOperationType op_type, typename LinalgOp>
+template <typename OpTy, typename OpAdaptor, typename LinalgOpTy>
+LogicalResult lowerDotOp(ConversionPatternRewriter &rewriter,
+                         const TypeConverter *typeConverter, OpTy op,
+                         OpAdaptor adaptor) {
+  if (!opMatchesLinalgTarget<LinalgOpTy>(op)) return failure();
+
+  auto loc = op.getLoc();
+
+  // Convert unsigned to signed. This works because signed and unsigned
+  // integer matmul is the same operation in two's complement.
+  auto outputType = cast<ShapedType>(typeConverter->convertType(op.getType()));
+
+  SmallVector<Value, 2> dynShape = getDotOpEmptyTensorDynSizes<LinalgOpTy>(
+      rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
+
+  Value emptyTensor =
+      !sparse_tensor::getSparseTensorEncoding(outputType)
+          ? getEmptyTensor(rewriter, loc, outputType, dynShape)
+          : getEmptySparseTensor(rewriter, loc, outputType, dynShape);
+  Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
+
+  rewriter.replaceOpWithNewOp<LinalgOpTy>(
+      op, TypeRange{outputType}, ValueRange{adaptor.getLhs(), adaptor.getRhs()},
+      ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op));
+  return success();
+}
+
+template <typename LinalgOpTy>
struct DotOpConversion final : OpConversionPattern<mlir::stablehlo::DotOp> {
  using OpConversionPattern<mlir::stablehlo::DotOp>::OpConversionPattern;
  using OpAdaptor = mlir::stablehlo::DotOp::Adaptor;

  LogicalResult matchAndRewrite(
      mlir::stablehlo::DotOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
-    if (getDotOperationType(op) != op_type) return failure();
-
-    Location loc = op.getLoc();
-    // Convert unsigned to signed. This works because signed and unsigned
-    // integer matmul is the same operation in two's complement.
-    auto outputType =
-        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
-    SmallVector<Value, 2> dynShape = getDotOpEmptyTensorDynSizes(
-        rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type);
-    Value emptyTensor =
-        !sparse_tensor::getSparseTensorEncoding(outputType)
-            ? getEmptyTensor(rewriter, loc, outputType, dynShape)
-            : getEmptySparseTensor(rewriter, loc, outputType, dynShape);
-    Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
-    rewriter.replaceOpWithNewOp<LinalgOp>(
-        op, TypeRange{outputType},
-        ValueRange{adaptor.getLhs(), adaptor.getRhs()}, ValueRange{zeroTensor},
-        linalg::getPrunedAttributeList(op));
-    return success();
+    return lowerDotOp<DotOp, OpAdaptor, LinalgOpTy>(
+        rewriter, getTypeConverter(), op, adaptor);
  }
};

@@ -177,6 +173,26 @@ struct DotGeneralOpConversion final
LogicalResult matchAndRewrite(
mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (op.isSimpleDot()) {
if (succeeded(lowerDotOp<DotGeneralOp, OpAdaptor, linalg::MatmulOp>(
rewriter, getTypeConverter(), op, adaptor)))
return success();
if (succeeded(lowerDotOp<DotGeneralOp, OpAdaptor, linalg::MatvecOp>(
rewriter, getTypeConverter(), op, adaptor)))
return success();
if (succeeded(lowerDotOp<DotGeneralOp, OpAdaptor, linalg::VecmatOp>(
rewriter, getTypeConverter(), op, adaptor)))
return success();
if (succeeded(lowerDotOp<DotGeneralOp, OpAdaptor, linalg::DotOp>(
rewriter, getTypeConverter(), op, adaptor)))
return success();
std::string str;
llvm::raw_string_ostream os(str);
os << "supposedly simple DotGeneralOp could not be converted: ";
op.print(os);
llvm::report_fatal_error(str.c_str());
}

// Get various dimension iterator information
mlir::stablehlo::DotDimensionNumbersAttr dimNumbers =
op.getDotDimensionNumbers();
@@ -270,13 +286,11 @@ void populateStablehloDotProdToLinalgConversionPatterns(
RewritePatternSet *patterns) {
// Ensure specialized patterns are higher priority than their generic
// versions.
-  patterns
-      ->add<DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
-            DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
-            DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
-            DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
-            DotGeneralBatchMatMulOpConversion>(typeConverter, context,
-                                               PatternBenefit(2));
+  patterns->add<
+      DotOpConversion<linalg::MatmulOp>, DotOpConversion<linalg::MatvecOp>,
+      DotOpConversion<linalg::VecmatOp>, DotOpConversion<linalg::DotOp>,
+      DotGeneralBatchMatMulOpConversion>(typeConverter, context,
+                                         PatternBenefit(2));
patterns->add<DotGeneralOpConversion>(typeConverter, context,
PatternBenefit(1));
}
15 changes: 15 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
@@ -635,6 +635,21 @@ mlir::Speculation::Speculatability DotGeneralOp::getSpeculatability() {
return mlir::Speculation::Speculatable;
}

DotDimensionNumbersAttr getDefaultDotDimensionNumbers(mlir::Value lhs) {
return DotDimensionNumbersAttr::get(
lhs.getContext(),
/*lhsBatchingDimensions=*/{},
/*rhsBatchingDimensions=*/{},
/*lhsContractingDimensions=*/
{cast<ShapedType>(lhs.getType()).getRank() - 1},
/*rhsContractingDimensions=*/{0});
}

bool DotGeneralOp::isSimpleDot() {
return getDotDimensionNumbersAttr() ==
getDefaultDotDimensionNumbers(getLhs());
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
4 changes: 4 additions & 0 deletions stablehlo/dialect/StablehloOps.h
@@ -144,6 +144,10 @@ ParseResult parseWindowAttributes(OpAsmParser &parser, Attribute &windowStrides,
namespace mlir {
namespace stablehlo {

// Returns the dimension numbers for a DotGeneral op that can be expressed as
// a DotOp, given the LHS of such an operation.
DotDimensionNumbersAttr getDefaultDotDimensionNumbers(mlir::Value lhs);

SortOp createSortOp(PatternRewriter *rewriter, const Location &loc,
const llvm::ArrayRef<Value> &operands,
const llvm::ArrayRef<Type> &elementTypes, int64_t dimension,
3 changes: 3 additions & 0 deletions stablehlo/dialect/StablehloOps.td
@@ -2424,6 +2424,9 @@ def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general",
let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();

/// Determines if this DotGeneral instance can be expressed as a DotOp.
bool isSimpleDot();
}];
}

10 changes: 0 additions & 10 deletions stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp
@@ -122,16 +122,6 @@ DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue) {
llvm::report_fatal_error("unsupported type");
}

DotDimensionNumbersAttr getDefaultDotDimensionNumbers(mlir::Value dotOpLhs) {
return DotDimensionNumbersAttr::get(
dotOpLhs.getContext(),
/*lhsBatchingDimensions=*/{},
/*rhsBatchingDimensions=*/{},
/*lhsContractingDimensions=*/
{cast<ShapedType>(dotOpLhs.getType()).getRank() - 1},
/*rhsContractingDimensions=*/{0});
}

DenseI64ArrayAttr getBroadcastDimensions(RankedTensorType resultType,
DenseI64ArrayAttr broadcastSizes) {
int64_t operandRank = resultType.getRank() - broadcastSizes.size();
