Skip to content

Commit

Permalink
Move qlinear before concat to allow output fusion (#3782)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Feb 5, 2025
1 parent 8bf1a49 commit 3aee3a3
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ struct find_concat_op
// Decides whether `op` is accepted by find_concat_op.
// An op is accepted when it is one of the listed shape/unpack ops, or when it
// carries the "pointwise" attribute and is not "quantizelinear".
// NOTE(review): quantizelinear is presumably excluded so that this pass does not
// interfere with the concat/quantizelinear rewrite in simplify_qdq - confirm.
static bool is_valid_op(const operation& op)
{
    return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or
           (op.attributes().contains("pointwise") and op.name() != "quantizelinear");
}

static bool is_valid_concat(std::vector<instruction_ref> ins, size_t axis)
Expand Down
53 changes: 52 additions & 1 deletion src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -34,6 +34,7 @@
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/fp8_types.hpp>
#include <migraphx/match/dq_helpers.hpp>
Expand Down Expand Up @@ -348,6 +349,54 @@ struct match_qlinear_reused
}
};

// Rewrites
//     quantizelinear(concat(x0, x1, ...), scale[, zero_point])
// into
//     concat(quantizelinear(x0, scale_0[, zp_0]), quantizelinear(x1, scale_1[, zp_1]), ...)
// where scale_i / zp_i are slices of the original quantization parameters taken
// along the concat axis over the range covered by input i.
struct match_concat_qlinear
{
    // Matches a quantizelinear whose first argument is a concat that is used
    // once and has at least one pointwise input. The concat is bound as "cat".
    auto matcher() const
    {
        auto any_pointwise_input = match::any_of[match::inputs()](match::pointwise());
        return match::name("quantizelinear")(match::arg(0)(
            match::name("concat")(match::used_once(), any_pointwise_input).bind("cat")));
    }

    // Returns, for every input of `cat_ins` in order, the attribute list of the
    // slice op ("axes", "starts", "ends") selecting that input's range along
    // the concat axis.
    // NOTE(review): the axis is used directly as an index into lens(), so it is
    // assumed to be non-negative (already normalized) here - confirm.
    auto get_slices(instruction_ref cat_ins) const
    {
        std::vector<std::vector<std::pair<std::string, value>>> slices;
        auto axis    = any_cast<op::concat>(cat_ins->get_operator()).axis;
        size_t start = 0;
        for(auto cat_inp : cat_ins->inputs())
        {
            auto end = start + cat_inp->get_shape().lens()[axis];
            slices.push_back({{"axes", {axis}}, {"starts", {start}}, {"ends", {end}}});
            // The next input begins where this one stops.
            start = end;
        }
        return slices;
    }

    // Replaces the matched quantizelinear with a concat of per-input
    // quantizelinear instructions, all inserted before the matched instruction.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins     = r.result;
        auto cat_ins = r.instructions["cat"];

        // Input 0 is the concat; everything after it is a quantization
        // parameter (scale and, when present, zero point). Slicing whatever is
        // there avoids indexing a zero point that was never provided.
        const auto& q_inputs = ins->inputs();
        assert(q_inputs.size() == 2 or q_inputs.size() == 3);
        const std::vector<instruction_ref> q_params(q_inputs.begin() + 1, q_inputs.end());

        const auto slices = get_slices(cat_ins);
        std::vector<instruction_ref> new_cat_inputs;
        new_cat_inputs.reserve(slices.size());
        std::transform(cat_ins->inputs().begin(),
                       cat_ins->inputs().end(),
                       slices.begin(),
                       std::back_inserter(new_cat_inputs),
                       [&](auto cat_inp, const auto& slc) {
                           // Argument order mirrors the original op:
                           // data, sliced scale, then sliced zero point (if any).
                           std::vector<instruction_ref> q_args{cat_inp};
                           for(auto param : q_params)
                           {
                               q_args.push_back(
                                   m.insert_instruction(ins, make_op("slice", slc), {param}));
                           }
                           return m.insert_instruction(ins, ins->get_operator(), q_args);
                       });

        m.replace_instruction(ins, cat_ins->get_operator(), new_cat_inputs);
    }
};

bool is_same_value(instruction_ref a, instruction_ref b)
{
if(a == b)
Expand Down Expand Up @@ -456,6 +505,8 @@ void simplify_qdq::apply(module& m) const
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
}

Expand Down
101 changes: 100 additions & 1 deletion test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -1539,4 +1539,103 @@ TEST_CASE(int4_simplify_qdq_pass_test)
EXPECT(migraphx::contains(res_1.instructions, "q"));
}

// quantizelinear applied to concat(i1, relu(i2)) with a scalar scale literal:
// after the pass the module is expected to quantize each concat input
// separately (using slices of the broadcast scale / zero point) and concat
// the quantized results.
TEST_CASE(pointwise_concat_quant_per_tensor)
{
    const migraphx::shape four_ch{migraphx::shape::float_type, {1, 4, 28, 28}};
    const migraphx::shape two_ch{migraphx::shape::float_type, {1, 2, 28, 28}};

    // Input graph: quantize after the concat.
    migraphx::module m1;
    {
        auto in_a       = m1.add_parameter("i1", four_ch);
        auto in_b       = m1.add_parameter("i2", two_ch);
        auto scale_lit  = m1.add_literal(0.5f);
        auto zero_lit   = m1.add_literal(std::int8_t{0});
        auto relu_b     = m1.add_instruction(migraphx::make_op("relu"), in_b);
        auto joined     = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), in_a, relu_b);
        auto quantized  = add_quantize_op(m1, "quantizelinear", joined, scale_lit, zero_lit);
        m1.add_return({quantized});
    }

    // Expected graph: quantize each input, then concat.
    migraphx::module m2;
    {
        const std::vector<std::size_t> cat_lens{1, 6, 28, 28};
        // Channel ranges occupied by the first and second concat input.
        const auto head_slice =
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}});
        const auto tail_slice =
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}});

        auto in_a      = m2.add_parameter("i1", four_ch);
        auto in_b      = m2.add_parameter("i2", two_ch);
        auto scale_lit = m2.add_literal(0.5f);
        auto zero_lit  = m2.add_literal(std::int8_t{0});

        auto relu_b   = m2.add_instruction(migraphx::make_op("relu"), in_b);
        auto scale_mb = broadcast_scale(m2, scale_lit, cat_lens, 1);
        auto zero_mb  = broadcast_shift(m2, zero_lit, cat_lens);

        auto head_scale = m2.add_instruction(head_slice, scale_mb);
        auto head_zero  = m2.add_instruction(head_slice, zero_mb);
        auto head_q     = add_quantize_op(m2, "quantizelinear", in_a, head_scale, head_zero);

        auto tail_scale = m2.add_instruction(tail_slice, scale_mb);
        auto tail_zero  = m2.add_instruction(tail_slice, zero_mb);
        auto tail_q     = add_quantize_op(m2, "quantizelinear", relu_b, tail_scale, tail_zero);

        auto joined =
            m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), head_q, tail_q);
        m2.add_return({joined});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}

// Same rewrite as the per-tensor case, but the scale is a generated literal of
// shape {6} (one entry per concatenated channel) instead of a scalar.
TEST_CASE(pointwise_concat_quant_per_channel)
{
    const migraphx::shape four_ch{migraphx::shape::float_type, {1, 4, 28, 28}};
    const migraphx::shape two_ch{migraphx::shape::float_type, {1, 2, 28, 28}};
    const migraphx::shape scale_shape{migraphx::shape::float_type, {6}};

    // Input graph: quantize after the concat.
    migraphx::module m1;
    {
        auto in_a      = m1.add_parameter("i1", four_ch);
        auto in_b      = m1.add_parameter("i2", two_ch);
        auto scale_lit = m1.add_literal(migraphx::generate_literal(scale_shape, 0));
        auto zero_lit  = m1.add_literal(std::int8_t{0});
        auto relu_b    = m1.add_instruction(migraphx::make_op("relu"), in_b);
        auto joined    = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), in_a, relu_b);
        auto quantized = add_quantize_op(m1, "quantizelinear", joined, scale_lit, zero_lit);
        m1.add_return({quantized});
    }

    // Expected graph: quantize each input, then concat.
    migraphx::module m2;
    {
        const std::vector<std::size_t> cat_lens{1, 6, 28, 28};
        // Channel ranges occupied by the first and second concat input.
        const auto head_slice =
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}});
        const auto tail_slice =
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}});

        auto in_a      = m2.add_parameter("i1", four_ch);
        auto in_b      = m2.add_parameter("i2", two_ch);
        auto scale_lit = m2.add_literal(migraphx::generate_literal(scale_shape, 0));
        auto zero_lit  = m2.add_literal(std::int8_t{0});

        auto relu_b   = m2.add_instruction(migraphx::make_op("relu"), in_b);
        auto scale_mb = broadcast_scale(m2, scale_lit, cat_lens, 1);
        auto zero_mb  = broadcast_shift(m2, zero_lit, cat_lens);

        auto head_scale = m2.add_instruction(head_slice, scale_mb);
        auto head_zero  = m2.add_instruction(head_slice, zero_mb);
        auto head_q     = add_quantize_op(m2, "quantizelinear", in_a, head_scale, head_zero);

        auto tail_scale = m2.add_instruction(tail_slice, scale_mb);
        auto tail_zero  = m2.add_instruction(tail_slice, zero_mb);
        auto tail_q     = add_quantize_op(m2, "quantizelinear", relu_b, tail_scale, tail_zero);

        auto joined =
            m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), head_q, tail_q);
        m2.add_return({joined});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}

// Test binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}

0 comments on commit 3aee3a3

Please sign in to comment.