Skip to content

Commit

Permalink
feat: Also canonicalize select and geq with different int types
Browse files Browse the repository at this point in the history
  • Loading branch information
roberteg16 committed Jan 31, 2025
1 parent d20ac95 commit c3e6e1f
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
54 changes: 47 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,43 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
return rewriter.notifyMatchFailure(
op, "RHS of predicate GreaterEqualOp is not a constant");
}

auto isCompatibleSplat = [](DenseElementsAttr a,
DenseElementsAttr b) -> bool {
if (!a.isSplat() || !b.isSplat()) {
return false;
}
if (llvm::isa<IntegerType>(a.getElementType())) {
return a.getSplatValue<APInt>() == b.getSplatValue<APInt>();

auto aAsIntegerType = dyn_cast<IntegerType>(a.getElementType());
auto bAsIntegerType = dyn_cast<IntegerType>(b.getElementType());
if (aAsIntegerType && bAsIntegerType) {
if (aAsIntegerType.getSignedness() != bAsIntegerType.getSignedness()) {
return false;
}

auto aAsAPInt = a.getSplatValue<APInt>();
auto bAsAPInt = b.getSplatValue<APInt>();

const size_t aBitWith = aAsAPInt.getBitWidth();
const size_t bBitWith = bAsAPInt.getBitWidth();

if (aBitWith >= bBitWith) {
return aAsAPInt == (bAsIntegerType.isUnsigned()
? bAsAPInt.zext(aBitWith)
: bAsAPInt.sext(aBitWith));
}
return (aAsIntegerType.isUnsigned()
? aAsAPInt.zext(bBitWith)
: aAsAPInt.sext(bBitWith)) == bAsAPInt;
}
if (llvm::isa<FloatType>(a.getElementType())) {
return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();

auto aAsFloatType = dyn_cast<FloatType>(a.getElementType());
auto bAsFloatType = dyn_cast<FloatType>(b.getElementType());
if (!aAsFloatType || aAsFloatType != bAsFloatType) {
return false;
}
return false; // Only int and float types are supported

return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();
};

auto onFalse = op.getOnFalse();
Expand Down Expand Up @@ -237,10 +262,25 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue);
}
}

Value input = geq.getInput1();

// If the two element types do not have the same bit width, insert a cast so
// that this canonicalization can still be applied.
const size_t geqBitWidth =
geq.getInput1().getType().getElementTypeBitWidth();
const size_t selectBitWidth = op.getType().getElementTypeBitWidth();
if (geqBitWidth != selectBitWidth) {
input = rewriter.create<tosa::CastOp>(
op->getLoc(),
geq.getInput1().getType().clone(op.getType().getElementType()),
input);
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), geq.getInput1(),
rewriter.getI64IntegerAttr(clampIntMin),
op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin),
rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax);

return success();
}
};
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1147,3 +1147,59 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) ->
return %3 : tensor<13x21x3xf32>
}

// -----

// Select/greater_equal pair where the comparison operates on i64 tensors while
// the select produces i8. The expected output is a tosa.cast from i64 to i8
// feeding a tosa.clamp with min_int = 42 and max_int = 9223372036854775807
// (the largest signed 64-bit value).
// NOTE(review): the select's on-true operand is %arg1 while greater_equal
// reads %arg0, and the cast expectation below accepts any %arg operand --
// confirm which value is meant to be cast and clamped.
// NOTE(review): here the cast narrows i64 to i8 before the clamp is applied --
// confirm this is the intended behavior for inputs outside the i8 range.
// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat1
func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
// Two splat constants holding the same value (42) with different integer
// widths: i64 for the comparison, i8 for the select.
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
// Predicate: %arg0 >= 42, evaluated on i64.
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
// The i8 splat constant is the select's last (on-false) operand.
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
return %3 : tensor<13x21x3xi8>
}

// -----

// Same i64-compare / i8-select setup as the pat1 variant, but with the splat
// constant (-42) as the select's on-true operand. The expected output is a
// tosa.cast from i64 to i8 feeding a tosa.clamp with max_int = -42 and
// min_int = -9223372036854775808 (the smallest signed 64-bit value).
// NOTE(review): the select's non-constant operand is %arg1 while greater_equal
// reads %arg0, and the cast expectation below accepts any %arg operand --
// confirm which value is meant to be cast and clamped.
// NOTE(review): here the cast narrows i64 to i8 before the clamp is applied --
// confirm this is the intended behavior for inputs outside the i8 range.
// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat2
func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
// Two splat constants holding the same negative value (-42) with different
// integer widths: i64 for the comparison, i8 for the select.
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
// Predicate: %arg0 >= -42, evaluated on i64.
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
// The i8 splat constant is the select's middle (on-true) operand here.
%3 = tosa.select %2, %1, %arg1 : ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
return %3 : tensor<13x21x3xi8>
}

// -----

// Select/greater_equal pair where the comparison operates on i8 tensors while
// the select produces i64. The expected output is a tosa.cast from i8 to i64
// feeding a tosa.clamp with min_int = 42 and max_int = 9223372036854775807
// (the largest signed 64-bit value).
// NOTE(review): the select's on-true operand is %arg1 while greater_equal
// reads %arg0, and the cast expectation below accepts any %arg operand --
// confirm which value is meant to be cast and clamped.
// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat1
func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
// Two splat constants holding the same value (42) with different integer
// widths: i8 for the comparison, i64 for the select.
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
// Predicate: %arg0 >= 42, evaluated on i8.
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
// The i64 splat constant is the select's last (on-false) operand.
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
return %3 : tensor<13x21x3xi64>
}

// -----

// Same i8-compare / i64-select setup as the pat1 variant, but with the splat
// constant (-42) as the select's on-true operand. The expected output is a
// tosa.cast from i8 to i64 feeding a tosa.clamp with max_int = -42 and
// min_int = -9223372036854775808 (the smallest signed 64-bit value).
// NOTE(review): the select's non-constant operand is %arg1 while greater_equal
// reads %arg0, and the cast expectation below accepts any %arg operand --
// confirm which value is meant to be cast and clamped.
// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat2
func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
// Two splat constants holding the same negative value (-42) with different
// integer widths: i8 for the comparison, i64 for the select.
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
// Predicate: %arg0 >= -42, evaluated on i8.
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
// The i64 splat constant is the select's middle (on-true) operand here.
%3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
return %3 : tensor<13x21x3xi64>
}

0 comments on commit c3e6e1f

Please sign in to comment.