Skip to content

Commit

Permalink
Change definition of dequantizelinear to match MIGraphX, ONNX (#1567)
Browse files Browse the repository at this point in the history
While debugging accuracy issues related to an MIGraphX change that
caused the emission of ops like

    migraphx.dequantizelinear %x, %bias, %scale : <...xi8>, <...xi8>, <...xf32>

we discovered that our understanding of MIGraphX's semantics for
dequantizelinear was incorrect or out of date. The intended
semantics, as seen in MIGraphX's quantization simplifier and, more
importantly, in the ONNX reference at
https://github.com/onnx/onnx/blob/c7717bb39c684a6d86c82a5bb1d4c0e5d90353fe/onnx/reference/ops/op_dequantize_linear.py#L42
are that the input and bias must be cast to the scale/output
element type (the scale and output element types are required to
match) before the computation is performed.

We didn't do this, causing validation errors due to 8-bit integer
over/underflow.

This commit fixes the issue. It isn't known whether it needs a backport.
  • Loading branch information
krzysz00 authored and CharlieL7 committed Aug 8, 2024
1 parent 2d2ac48 commit 1324731
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
34 changes: 14 additions & 20 deletions mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,11 @@ struct SoftmaxConverter final
};
} // namespace

// MIGraphX pseudo code:
// output[i] = static_cast<T>(input[i] - zero_pts[i]) * scales[i];
// MIGraphX implements:
// Let T = scale element type
// output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
// For f32 this matches the ONNX reference. Dequantizing to f16, if it's ever
// done, will be less precise than the reference, but that's probably fine.
LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -847,34 +850,25 @@ LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
Value output = op.getOutput();
Location loc = op->getLoc();

Value shifted = input;
Type outputType = getShapedElementTy(output);
Value upcastInput = createCastOp(rewriter, loc, outputType, input);

Value shifted = upcastInput;
if (auto bias = adaptor.getBias()) {
Type inElemTy = getShapedElementTy(input);
Type biasElemTy = getShapedElementTy(bias);
Type elementType =
inElemTy.getIntOrFloatBitWidth() <= biasElemTy.getIntOrFloatBitWidth()
? biasElemTy
: inElemTy;
if (inElemTy != elementType)
input = createCastOp(rewriter, loc, elementType, shifted);
if (biasElemTy != elementType)
bias = createCastOp(rewriter, loc, elementType, bias);
shifted =
createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType, input, bias);
Value upcastBias = createCastOp(rewriter, loc, outputType, bias);
shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, outputType,
upcastInput, upcastBias);
}

Type outputType = getShapedElementTy(output);
Value upCast = createCastOp(rewriter, loc, outputType, shifted);

Value scaled = createOpAndInfer<tosa::MulOp>(rewriter, loc, outputType,
upCast, scale, /*shift=*/0);
shifted, scale, /*shift=*/0);

rewriter.replaceOp(op, scaled);
return success();
}

// MIGraphX pseudo code:
// int64_t quantized = static_cast<int32>(
// int32_t quantized = static_cast<int32>(
// std::round(input[i] / scales[i])) + zero_pts[i];
// output[i] = std::max(-128, std::min(127, quantized));
LogicalResult QuantizeLinearConverter::matchAndRewrite(
Expand Down
14 changes: 8 additions & 6 deletions mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,29 @@ module {
}

// CHECK-LABEL: func @dequantize_scale_bias
// CHECK: tosa.sub
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub
// CHECK: tosa.mul
func.func @dequantize_scale_bias(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
}

// CHECK-LABEL: func @dequantize_wide_bias
// CHECK: tosa.cast{{.*}}i32
// CHECK: tosa.sub{{.*}}i32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub{{.*}}f32
// CHECK: tosa.mul
func.func @dequantize_wide_bias(%arg: !migraphx.shaped<1x112x112x64xi8, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi8, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
}

// CHECK-LABEL: func @dequantize_wide_input
// CHECK: tosa.cast{{.*}}i32
// CHECK: tosa.sub{{.*}}i32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub{{.*}}f32
// CHECK: tosa.mul
func.func @dequantize_wide_input(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi8, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi8, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
Expand Down Expand Up @@ -142,8 +143,9 @@ module {

// CHECK-LABEL: func @conv_with_quant
// CHECK: tosa.conv2d{{.*}} quantization_info
// CHECK: tosa.sub
// CHECK: tosa.cast
// CHECK: tosa.cast
// CHECK: tosa.sub
// CHECK: tosa.mul
// CHECK: tosa.reciprocal
// CHECK: tosa.mul
Expand Down

0 comments on commit 1324731

Please sign in to comment.