Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/onnx-to-tosa' into bump_…
Browse files Browse the repository at this point in the history
…to_86dbaf04
  • Loading branch information
jorickert committed Feb 7, 2025
2 parents a4aa907 + f4d1ac4 commit 64df722
Show file tree
Hide file tree
Showing 25 changed files with 565 additions and 136 deletions.
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ project(onnx-mlir)
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
Expand Down Expand Up @@ -223,12 +222,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
endif()

if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 1)
else()
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 0)
endif()

add_subdirectory(utils)
add_subdirectory(include)
Expand Down
7 changes: 7 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ bool enableParallel; // onnx-mlir only
bool disableSimdOption; // onnx-mlir only
bool enableFastMathOption; // onnx-mlir only
bool disableRecomposeOption; // onnx-mlir only
bool disableConvTransposeDecomposeOption; // onnx-mlir only
bool enableSimdDataLayout; // onnx-mlir only
bool verifyInputTensors; // onnx-mlir only
bool allowSorting; // onnx-mlir only
Expand Down Expand Up @@ -257,6 +258,12 @@ static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> disableConvTranposeDecomposeOptionOpt(
"disable-convtranspose-decompose",
llvm::cl::desc("Disable decomposition of ONNX ConvTranspose operator."),
llvm::cl::location(disableConvTransposeDecomposeOption),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));

// Options for onnx-mlir only
static llvm::cl::opt<EmissionTargetType, true> emissionTargetOpt(
llvm::cl::desc("Choose target to emit:"),
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ extern bool enableParallel; // onnx-mlir only
extern bool disableSimdOption; // onnx-mlir only
extern bool enableFastMathOption; // onnx-mlir only
extern bool disableRecomposeOption; // onnx-mlir only
extern bool disableConvTransposeDecomposeOption; // onnx-mlir only
extern bool enableSimdDataLayout; // onnx-mlir only
extern bool verifyInputTensors; // onnx-mlir only
extern bool allowSorting; // onnx-mlir only
Expand Down
58 changes: 35 additions & 23 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ Value TosaBuilder::createConst(
}

bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) {
if (llvm::any_of(valueRange, [](const auto value) {
return !mlir::cast<ShapedType>(value.getType()).hasRank();
})) {
return false; // we have no way to determine the broadcast, so do not
// attempt it
}
int64_t firstRank = mlir::cast<ShapedType>(valueRange[0].getType()).getRank();
for (Value operand : valueRange) {
auto operandType = mlir::cast<ShapedType>(operand.getType());
Expand Down Expand Up @@ -129,9 +135,8 @@ Value TosaBuilder::getConst(ArrayRef<float> vec, ArrayRef<int64_t> shape) {
return constOp;
}

Value TosaBuilder::getSplattedConst(
float val, Type dtype, llvm::ArrayRef<int64_t> shape) {
auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
Value TosaBuilder::getSplattedConst(float val, Type dtype, int64_t rank) {
auto constType = tosa::reduceAxisToOne(rank, rewriter().getF32Type());
auto constAttr = DenseElementsAttr::get(constType, val);

auto constOp =
Expand All @@ -150,8 +155,7 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {
auto valueType = mlir::cast<ShapedType>(value.getType());
// get new value type
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(
valueType.getShape().size(), ShapedType::kDynamic),
llvm::SmallVector<int64_t, 4>(perm.size(), ShapedType::kDynamic),
valueType.getElementType());
// create transpose for value
Value newValue = tosa::CreateOpAndInfer<mlir::tosa::TransposeOp>(
Expand Down Expand Up @@ -195,9 +199,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
rhs = valueVec[1];
}
auto lhsType = mlir::cast<ShapedType>(lhs.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
Type newValueType =
(!lhsType.hasRank())
? lhsType
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
}
Expand All @@ -215,9 +222,12 @@ Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
}

auto lhsType = mlir::cast<ShapedType>(lhs.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsElementType);
Type newValueType =
(!lhsType.hasRank())
? lhsType
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsElementType);
return tosa::CreateOpAndInfer<mlir::tosa::IntDivOp>(
rewriter(), loc(), newValueType, lhs, rhs);
}
Expand All @@ -230,9 +240,12 @@ Value TosaBuilder::binaryOp(Value &lhs, Value &rhs) {
rhs = valueVec[1];
}
auto lhsType = mlir::cast<ShapedType>(lhs.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
Type newValueType =
(!lhsType.hasRank())
? lhsType
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, lhs, rhs);
}

Expand All @@ -246,11 +259,7 @@ template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(

template <typename T>
Value TosaBuilder::unaryOp(mlir::Value &input) {
auto inputType = cast<ShapedType>(input.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(inputType.getRank(), ShapedType::kDynamic),
inputType.getElementType());
return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, input);
return tosa::CreateOpAndInfer<T>(rewriter(), loc(), input.getType(), input);
}

template Value TosaBuilder::unaryOp<mlir::tosa::ExpOp>(mlir::Value &input);
Expand Down Expand Up @@ -305,9 +314,12 @@ Value TosaBuilder::select(
rhs = valueVec[2];
}
auto lhsType = cast<ShapedType>(lhs.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
Type newValueType =
(!lhsType.hasRank())
? lhsType
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
return tosa::CreateOpAndInfer<mlir::tosa::SelectOp>(
rewriter(), loc(), newValueType, cond, lhs, rhs);
}
Expand All @@ -328,7 +340,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType(
Value TosaBuilder::sqrt(mlir::Value &input) {
auto inputType = cast<ShapedType>(input.getType());
auto oneHalf = this->getSplattedConst(
0.5, inputType.getElementType(), inputType.getShape());
0.5, inputType.getElementType(), inputType.getRank());
return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
}

Expand Down
3 changes: 1 addition & 2 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ struct TosaBuilder : DialectBuilder {
// The tensor will have the same rank as shape but all dimensions will
// have size 1 (differs from tensorflow impl.)
// If dtype is provided, it also cast the value to the appropriate dtype.
mlir::Value getSplattedConst(
float val, mlir::Type dtype, llvm::ArrayRef<int64_t> shape = {});
mlir::Value getSplattedConst(float val, mlir::Type dtype, int64_t rank);

// Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
// to `value`.
Expand Down
65 changes: 48 additions & 17 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(
return success();
}

namespace {
template <typename OnnxOp>
void copySingleResultType(OnnxOp opToCopyFrom, Value &valueToCopyTo) {
assert(opToCopyFrom->getNumResults() == 1);
valueToCopyTo.setType(opToCopyFrom->getResult(0).getType());
}
} // namespace

// Element-wise unary ops lowering to TOSA dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOpONNX, typename ElementwiseUnaryOpTOSA,
Expand Down Expand Up @@ -197,6 +205,7 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {

TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value mulOp = tosaBuilder.mul(lhs, rhs);
copySingleResultType(op, mulOp);
rewriter.replaceOp(op, {mulOp});

return success();
Expand Down Expand Up @@ -230,12 +239,12 @@ static LogicalResult legalizeFloatingPointPrelu(Operation *op,
auto loc = op->getLoc();
TosaBuilder tosaBuilder(rewriter, loc);
Value constZero = tosaBuilder.getSplattedConst(
0.0, outputType.getElementType(), outputType.getShape());
0.0, outputType.getElementType(), outputType.getRank());

auto mul = tosaBuilder.mul(input, alphaOrSlope);
auto greaterEqual = tosaBuilder.greaterEqual(input, constZero);
auto select = tosaBuilder.select(greaterEqual, input, mul);

copySingleResultType(op, select);
rewriter.replaceOp(op, {select});
return success();
}
Expand Down Expand Up @@ -265,7 +274,7 @@ class ONNXLeakyReluOpLoweringToTOSA
TosaBuilder tosaBuilder(rewriter, loc);
return legalizeFloatingPointPrelu(op, rewriter, adaptor.getX(),
tosaBuilder.getSplattedConst(
alpha, outputType.getElementType(), outputType.getShape()),
alpha, outputType.getElementType(), outputType.getRank()),
outputType);
}
};
Expand Down Expand Up @@ -303,6 +312,7 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern<OnnxCompOp> {
} else if constexpr (std::is_same_v<OnnxCompOp, ONNXLessOp>) {
res = tosaBuilder.less(input1, input2);
}
copySingleResultType(op, res);
rewriter.replaceOp(op, {res});
return success();
}
Expand Down Expand Up @@ -384,7 +394,7 @@ class ONNXCastOpLoweringToTOSA : public OpConversionPattern<ONNXCastOp> {
// onnx.Cast and tosa.cast.
if (resultTy.getElementType().getIntOrFloatBitWidth() != 1) {
auto zero = tosaBuilder.getSplattedConst(
0.0f, inputTy.getElementType(), resultTy.getShape());
0.0f, inputTy.getElementType(), resultTy.getRank());
auto positive = tosaBuilder.greaterEqual(input, zero);

auto floor = tosaBuilder.unaryOp<mlir::tosa::FloorOp>(input);
Expand Down Expand Up @@ -412,13 +422,15 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {

if (isa<IntegerType>(resultElementType)) {
Value divOp = tosaBuilder.intdiv(lhs, rhs);
copySingleResultType(op, divOp);
rewriter.replaceOp(op, {divOp});
return success();
}
// For floating point types, decompose ONNXDivOp into
// tosa::ReciprocalOp and tosa::MulOp.
Value reciprocalOp = tosaBuilder.unaryOp<mlir::tosa::ReciprocalOp>(rhs);
Value mulOp = tosaBuilder.mul(lhs, reciprocalOp);
copySingleResultType(op, mulOp);
rewriter.replaceOp(op, {mulOp});
return success();
}
Expand Down Expand Up @@ -463,20 +475,21 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
TosaBuilder tosaBuilder(rewriter, op->getLoc());

Value one = tosaBuilder.getSplattedConst(
1.0, resultTensorType.getElementType(), resultTensorType.getShape());
1.0, resultTensorType.getElementType(), resultTensorType.getRank());
Value alpha =
tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
resultTensorType.getElementType(), resultTensorType.getShape());
resultTensorType.getElementType(), resultTensorType.getRank());
Value constZero = tosaBuilder.getSplattedConst(
0.0, resultTensorType.getElementType(), resultTensorType.getShape());
0.0, resultTensorType.getElementType(), resultTensorType.getRank());

Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
copySingleResultType(op, exp);
Value expMinusOne = tosaBuilder.binaryOp<mlir::tosa::SubOp>(exp, one);
Value alphaTimesExpMinusOne = tosaBuilder.mul(expMinusOne, alpha);
Value greaterEqual = tosaBuilder.greaterEqual(input, constZero);
auto select =
tosaBuilder.select(greaterEqual, input, alphaTimesExpMinusOne);

copySingleResultType(op, select);
rewriter.replaceOp(op, {select});
return success();
}
Expand Down Expand Up @@ -507,11 +520,16 @@ class ONNXHardSigmoidOpLoweringToTOSA
APFloat oneOverAlpha(alpha.getSemantics(), 1);
oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);

if (!resultType.hasRank()) {
return rewriter.notifyMatchFailure(
op, "HardSigmoid: Static shape required to create splatted const");
}

Value constBetaOverAlpha =
tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
resultElementType, resultType.getShape());
resultElementType, resultType.getRank());
Value constAlpha = tosaBuilder.getSplattedConst(
alpha.convertToDouble(), resultElementType, resultType.getShape());
alpha.convertToDouble(), resultElementType, resultType.getRank());

auto addOp =
tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
Expand All @@ -521,7 +539,7 @@ class ONNXHardSigmoidOpLoweringToTOSA
rewriter.getF32FloatAttr(0),
rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
auto mulOp = tosaBuilder.mul(clampOp, constAlpha);

copySingleResultType(op, mulOp);
rewriter.replaceOp(op, {mulOp});
return success();
}
Expand Down Expand Up @@ -556,14 +574,19 @@ class ONNXSoftplusOpLoweringToTOSA
if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
return failure();
}
if (!outputType.hasRank()) {
return rewriter.notifyMatchFailure(
op, "ONNXSoftplusOp: Rank required to create splatted const");
}

Value input = adaptor.getX();

TosaBuilder tosaBuilder(rewriter, op->getLoc());
auto one = tosaBuilder.getSplattedConst(
1.0, outputType.getElementType(), outputType.getShape());
1.0, outputType.getElementType(), outputType.getRank());

auto expOp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
copySingleResultType(op, expOp);
auto expPlusOne = tosaBuilder.binaryOp<mlir::tosa::AddOp>(expOp, one);
auto logOp = tosaBuilder.unaryOp<mlir::tosa::LogOp>(expPlusOne);
rewriter.replaceOp(op, {logOp});
Expand All @@ -585,15 +608,19 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern<ONNXSeluOp> {
Value input = adaptor.getX();

TosaBuilder tosaBuilder(rewriter, op->getLoc());
if (!outputType.hasRank()) {
return rewriter.notifyMatchFailure(
op, "ONNXSeluOp: Rank required to create splatted const");
}

Value alpha =
tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
outputType.getElementType(), outputType.getShape());
outputType.getElementType(), outputType.getRank());
Value gamma =
tosaBuilder.getSplattedConst(adaptor.getGamma().convertToDouble(),
outputType.getElementType(), outputType.getShape());
outputType.getElementType(), outputType.getRank());
Value constZero = tosaBuilder.getSplattedConst(
0.0, outputType.getElementType(), outputType.getShape());
0.0, outputType.getElementType(), outputType.getRank());

Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
Value expTimesAlpha = tosaBuilder.mul(exp, alpha);
Expand Down Expand Up @@ -621,15 +648,19 @@ class ONNXThresholdedReluOpLoweringToTOSA
rewriter, outputType.getElementType(), op))) {
return failure();
}
if (!outputType.hasRank()) {
return rewriter.notifyMatchFailure(
op, "ONNXThresholdedReluOp: Rank required to create splatted const");
}

Value input = adaptor.getX();

TosaBuilder tosaBuilder(rewriter, op->getLoc());
auto alpha =
tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
outputType.getElementType(), outputType.getShape());
outputType.getElementType(), outputType.getRank());
auto zero = tosaBuilder.getSplattedConst(
0.0, outputType.getElementType(), outputType.getShape());
0.0, outputType.getElementType(), outputType.getRank());

auto greater = tosaBuilder.greater(input, alpha);
auto select = tosaBuilder.select(greater, input, zero);
Expand Down
Loading

0 comments on commit 64df722

Please sign in to comment.