From 05a7707b5cddbf8d02d7fe647b712a748eb78ca8 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Tue, 19 Mar 2024 11:54:03 -0400 Subject: [PATCH] Fixes to parse DynamicQuantizeLinear (#2896) (#2905) --- src/onnx/parse_dynamicquantizelinear.cpp | 44 +++++++++---------- .../parse/dynamicquantizelinear_2d_test.cpp | 34 ++++++++------ 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/onnx/parse_dynamicquantizelinear.cpp b/src/onnx/parse_dynamicquantizelinear.cpp index d945600f02f..8041d09eb8c 100644 --- a/src/onnx/parse_dynamicquantizelinear.cpp +++ b/src/onnx/parse_dynamicquantizelinear.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -98,30 +98,31 @@ struct parse_dynamicquantizelinear : op_parser if(x_shape.dynamic()) MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported"); - auto x_reshaped = - (x_shape.lens().size() == 1) - ? x - : info.add_instruction( - migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x); - auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}}); - x_reshaped = - info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0); // 1. Computing y_scale // Note: currently, DynamicQuantizeLinear only has uint8 quantization: - const auto x_max = std::numeric_limits::max(); - const auto x_min = std::numeric_limits::min(); - - auto q_range = - info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max - x_min}}); + const auto type_max = std::numeric_limits::max(); + const auto type_min = std::numeric_limits::min(); + std::vector axes(x_shape.lens().size()); + std::iota(axes.begin(), axes.end(), 0); // maximum(0, max(x)) - auto max_x = - info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped); + auto reduce_max_x = + info.add_instruction(migraphx::make_op("reduce_max", {{"axes", axes}}), x); + auto max_x = info.add_common_op("max", lit_0, reduce_max_x); + // minimum(0, min(x)) - auto min_x = - info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped); + auto reduce_min_x = + info.add_instruction(migraphx::make_op("reduce_min", {{"axes", axes}}), x); + auto min_x = info.add_common_op("min", lit_0, reduce_min_x); + + auto q_range = info.add_literal(migraphx::literal{ + migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max - type_min}}); + auto q_min = info.add_literal( + migraphx::literal{migraphx::shape{x_type, max_x->get_shape().lens()}, {type_min}}); + auto q_max = info.add_literal( + migraphx::literal{migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max}}); // y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin) auto sub0 = info.add_common_op("sub", max_x, min_x); @@ -129,10 +130,9 @@ struct parse_dynamicquantizelinear : op_parser // 2. Computing y_zero_point // intermediate_zero_point = qmin - min(x) / y_scale - auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_min}}); - auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max}}); - auto sub1 = info.add_common_op("sub", q_min, min_x); - auto interm_zp = info.add_common_op("div", sub1, y_scale); + auto div1 = info.add_common_op("div", min_x, y_scale); + auto interm_zp = info.add_common_op("sub", q_min, div1); + // y_zero_point = cast(round(saturate(itermediate_zero_point))) auto saturate = info.add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max); auto round = info.add_instruction(migraphx::make_op("nearbyint"), saturate); diff --git a/test/onnx/parse/dynamicquantizelinear_2d_test.cpp b/test/onnx/parse/dynamicquantizelinear_2d_test.cpp index 38da4238c4b..7468d1038a3 100644 --- a/test/onnx/parse/dynamicquantizelinear_2d_test.cpp +++ b/test/onnx/parse/dynamicquantizelinear_2d_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -32,25 +32,31 @@ TEST_CASE(dynamicquantizelinear_2d_test) auto x_type = migraphx::shape::float_type; auto x = mm->add_parameter("x", {x_type, x_dims}); - auto l0 = mm->add_literal({0.f}); - auto x_reshaped = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), x); - x_reshaped = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, l0); + auto l0 = mm->add_literal({0.f}); - auto q_range = mm->add_literal( - migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits::max()}}); + std::vector axes(x->get_shape().lens().size()); + std::iota(axes.begin(), axes.end(), 0); - auto max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped); - auto min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped); + auto reduce_max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", axes}}), x); + auto max_x = add_common_op(*mm, migraphx::make_op("max"), {l0, reduce_max_x}); + + auto reduce_min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", axes}}), x); + auto min_x = add_common_op(*mm, migraphx::make_op("min"), {l0, reduce_min_x}); + + auto q_range = mm->add_literal(migraphx::literal{ + migraphx::shape{x_type, max_x->get_shape().lens()}, + {std::numeric_limits::max() - std::numeric_limits::min()}}); + auto q_min = mm->add_literal(migraphx::literal{ + migraphx::shape{x_type, max_x->get_shape().lens()}, {std::numeric_limits::min()}}); + auto q_max = mm->add_literal(migraphx::literal{ + migraphx::shape{x_type, min_x->get_shape().lens()}, {std::numeric_limits::max()}}); auto sub0 = mm->add_instruction(migraphx::make_op("sub"), max_x, min_x); auto y_scale = mm->add_instruction(migraphx::make_op("div"), sub0, q_range); - auto q_min = mm->add_literal( - migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits::min()}}); - auto q_max = mm->add_literal( - migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits::max()}}); - auto sub1 = mm->add_instruction(migraphx::make_op("sub"), q_min, min_x); - auto interm_zp = mm->add_instruction(migraphx::make_op("div"), sub1, y_scale); + auto div1 = add_common_op(*mm, migraphx::make_op("div"), {min_x, y_scale}); + auto interm_zp = add_common_op(*mm, migraphx::make_op("sub"), {q_min, div1}); + auto saturate = mm->add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max); auto round = mm->add_instruction(migraphx::make_op("nearbyint"), saturate); auto y_zero_point = mm->add_instruction(