Skip to content

Commit

Permalink
Simplify quant_dot section and tests
Browse files Browse the repository at this point in the history
Clean up uint8 handling for quant_dot. Fix tests
  • Loading branch information
Ted Themistokleous committed Nov 7, 2024
1 parent 84d850b commit be3444a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 32 deletions.
56 changes: 26 additions & 30 deletions src/onnx/parse_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,26 @@ struct parse_matmul : op_parser<parse_matmul>
return res;
}

static void handle_uint8_input(const onnx_parser::node_info& info,
const bool has_bias,
const instruction_ref& offset_op,
instruction_ref& arg,
instruction_ref& bias_arg)
{
auto arg_type = arg->get_shape().type();
// always convert uint8 to int8 to avoid rollover
if(arg_type == migraphx::shape::uint8_type)
{
shift_input_and_bias(info, offset_op, has_bias, arg, bias_arg);
}

// subtract bias from result after conversion
if(has_bias)
{
bias_arg = info.add_common_op("sub", arg, bias_arg);
}
}

instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
Expand Down Expand Up @@ -415,37 +435,13 @@ struct parse_matmul : op_parser<parse_matmul>
MIGRAPHX_THROW(op_name + ": Unsupported type");
}

instruction_ref offset_op;
if(is_quant_dot)
if((is_quant_dot and ((a0_type == migraphx::shape::uint8_type) or
(a1_type == migraphx::shape::uint8_type))))
{
if(((a0_type == migraphx::shape::uint8_type) or
(a1_type == migraphx::shape::uint8_type)))
{
offset_op = info.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {-128}});
}

// always convert uint8 to int8 to avoid rollover
if((a0_type == migraphx::shape::uint8_type))
{
shift_input_and_bias(info, offset_op, has_ba0, a0, ba0);
}

if((a1_type == migraphx::shape::uint8_type))
{
shift_input_and_bias(info, offset_op, has_ba1, a1, ba1);
}

// subtract bias from result after conversion
if(has_ba0)
{
ba0 = info.add_common_op("sub", a0, ba0);
}

if(has_ba1)
{
ba1 = info.add_common_op("sub", a1, ba1);
}
auto offset_op = info.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {-128}});
handle_uint8_input(info, has_ba0, offset_op, a0, ba0);
handle_uint8_input(info, has_ba1, offset_op, a1, ba1);
}

broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1);
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/parse/matmulinteger_dual_zp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ TEST_CASE(matmulinteger_dual_zp_test)

auto prog = optimize_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx");

EXPECT(p == prog);
EXPECT(p.sort() == prog.sort());
}
2 changes: 1 addition & 1 deletion test/onnx/parse/matmulinteger_one_zp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ TEST_CASE(matmulinteger_one_zp_test)

auto prog = optimize_onnx("matmulinteger_int8_uint8_one_zp_test.onnx");

EXPECT(p == prog);
EXPECT(p.sort() == prog.sort());
}

0 comments on commit be3444a

Please sign in to comment.