diff --git a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp
index 054b975c2d45..5d35d4db33b2 100644
--- a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp
+++ b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp
@@ -837,8 +837,11 @@ struct SoftmaxConverter final
 };
 } // namespace
 
-// MIGraphX pseudo code:
-// output[i] = static_cast(input[i] - zero_pts[i]) * scales[i];
+// MIGraphX implements:
+// Let T = the scale element type;
+// output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
+// For f32 scales, this matches the ONNX reference. Dequantizing to f16, if it
+// is ever done, will be less precise than the reference, but that is fine.
 LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
     migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -847,34 +850,25 @@ LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
   Value output = op.getOutput();
   Location loc = op->getLoc();
 
-  Value shifted = input;
+  Type outputType = getShapedElementTy(output);
+  Value upcastInput = createCastOp(rewriter, loc, outputType, input);
+
+  Value shifted = upcastInput;
   if (auto bias = adaptor.getBias()) {
-    Type inElemTy = getShapedElementTy(input);
-    Type biasElemTy = getShapedElementTy(bias);
-    Type elementType =
-        inElemTy.getIntOrFloatBitWidth() <= biasElemTy.getIntOrFloatBitWidth()
-            ? biasElemTy
-            : inElemTy;
-    if (inElemTy != elementType)
-      input = createCastOp(rewriter, loc, elementType, shifted);
-    if (biasElemTy != elementType)
-      bias = createCastOp(rewriter, loc, elementType, bias);
-    shifted =
-        createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType, input, bias);
+    Value upcastBias = createCastOp(rewriter, loc, outputType, bias);
+    shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, outputType,
+                                            upcastInput, upcastBias);
   }
 
-  Type outputType = getShapedElementTy(output);
-  Value upCast = createCastOp(rewriter, loc, outputType, shifted);
-
   Value scaled = createOpAndInfer<tosa::MulOp>(rewriter, loc, outputType,
-                                               upCast, scale, /*shift=*/0);
+                                               shifted, scale, /*shift=*/0);
 
   rewriter.replaceOp(op, scaled);
   return success();
 }
 
 // MIGraphX pseudo code:
-// int64_t quantized = static_cast<int64_t>(
+// int32_t quantized = static_cast<int32_t>(
 //     std::round(input[i] / scales[i])) + zero_pts[i];
 // output[i] = std::max(-128, std::min(127, quantized));
 LogicalResult QuantizeLinearConverter::matchAndRewrite(
diff --git a/mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir b/mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir
index 8ad631ff7be5..80ee52996d64 100644
--- a/mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir
+++ b/mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir
@@ -28,8 +28,9 @@ module {
   }
 
   // CHECK-LABEL: func @dequantize_scale_bias
-  // CHECK: tosa.sub
   // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.sub
   // CHECK: tosa.mul
   func.func @dequantize_scale_bias(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
     %1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
@@ -37,9 +38,9 @@ module {
   }
 
   // CHECK-LABEL: func @dequantize_wide_bias
-  // CHECK: tosa.cast{{.*}}i32
-  // CHECK: tosa.sub{{.*}}i32
   // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.sub{{.*}}f32
   // CHECK: tosa.mul
   func.func @dequantize_wide_bias(%arg: !migraphx.shaped<1x112x112x64xi8, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
     %1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi8, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
@@ -47,9 +48,9 @@ module {
   }
 
   // CHECK-LABEL: func @dequantize_wide_input
-  // CHECK: tosa.cast{{.*}}i32
-  // CHECK: tosa.sub{{.*}}i32
   // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.cast{{.*}}f32
+  // CHECK: tosa.sub{{.*}}f32
   // CHECK: tosa.mul
   func.func @dequantize_wide_input(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi8, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
     %1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi8, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
@@ -142,8 +143,9 @@ module {
 
   // CHECK-LABEL: func @conv_with_quant
   // CHECK: tosa.conv2d{{.*}} quantization_info
-  // CHECK: tosa.sub
   // CHECK: tosa.cast
+  // CHECK: tosa.cast
+  // CHECK: tosa.sub
   // CHECK: tosa.mul
   // CHECK: tosa.reciprocal
   // CHECK: tosa.mul
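Illustration (not part of the patch): the change casts both operands to the output element type before the subtraction, instead of subtracting in the integer domain and casting afterwards. The standalone C++ sketch below shows one way the two orderings can disagree for a wide (i32) input and i32 zero point, the shapes exercised by the dequantize_scale_bias test above; the concrete values are made up purely to make the i32 subtraction wrap, and the sketch is not claimed to be the patch's motivation.

#include <cstdint>
#include <cstdio>

int main() {
  int32_t input = -2000000000;  // hypothetical i32 quantized value
  int32_t zeroPt = 2000000000;  // hypothetical i32 zero point (bias)
  float scale = 0.5f;

  // Old ordering: subtract in i32, then cast to f32. The i32 subtraction
  // wraps (simulated via uint32_t to avoid signed-overflow UB), so the
  // dequantized value is far from the mathematically expected one.
  int32_t wrapped = static_cast<int32_t>(static_cast<uint32_t>(input) -
                                         static_cast<uint32_t>(zeroPt));
  float subtractThenCast = static_cast<float>(wrapped) * scale;

  // New ordering: cast each operand to the scale/output element type first,
  // then subtract and multiply, as the updated comment describes.
  float castThenSubtract =
      (static_cast<float>(input) - static_cast<float>(zeroPt)) * scale;

  std::printf("subtract-then-cast: %g\n", subtractThenCast);  // ~1.47e8
  std::printf("cast-then-subtract: %g\n", castThenSubtract);  // -2e9
  return 0;
}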