Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move qlinear before concat to allow output fusion #3782

Merged
merged 6 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ struct find_concat_op
// Decides whether `op` is accepted by this concat simplification.
//
// Accepted operators are "broadcast", "multibroadcast", "unpack_int4", and any
// operator that carries the "pointwise" attribute -- except "quantizelinear",
// which is explicitly rejected even though it is pointwise.
// NOTE(review): the exclusion presumably keeps this pass from undoing the
// rewrite that moves quantizelinear in front of concat -- confirm against
// simplify_qdq.
static bool is_valid_op(const operation& op)
{
    return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or
           (op.attributes().contains("pointwise") and op.name() != "quantizelinear");
}

static bool is_valid_concat(std::vector<instruction_ref> ins, size_t axis)
Expand Down
54 changes: 53 additions & 1 deletion src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -34,6 +34,7 @@
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/fp8_types.hpp>
#include <migraphx/match/dq_helpers.hpp>
Expand Down Expand Up @@ -348,6 +349,55 @@
}
};

struct match_concat_qlinear
{
auto matcher() const
{
auto any_pointwise_input = match::any_of[match::inputs()](match::pointwise());
return match::name("quantizelinear")(
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
match::arg(0)(match::name("concat")(any_pointwise_input).bind("cat")));
}
auto get_slices(instruction_ref cat_ins) const
{
std::vector<std::vector<std::pair<std::string, value>>> slices;
auto axis = any_cast<op::concat>(cat_ins->get_operator()).axis;
size_t start = 0;
size_t end;

Check warning on line 365 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: The scope of the variable 'end' can be reduced. [variableScope]
for(auto cat_inp : cat_ins->inputs())
{
end = start + cat_inp->get_shape().lens()[axis];
slices.push_back({{"axes", {axis}}, {"starts", {start}}, {"ends", {end}}});
start = end;
}
return slices;
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto cat_ins = r.instructions["cat"];

assert(ins->inputs().size() == 3);
auto scale = ins->inputs()[1];
auto zp = ins->inputs()[2];

auto slices = get_slices(cat_ins);
std::vector<instruction_ref> new_cat_inputs;
std::transform(
cat_ins->inputs().begin(),
cat_ins->inputs().end(),
slices.begin(),
std::back_inserter(new_cat_inputs),
[&](auto i, auto slc) {
auto scale_slc = m.insert_instruction(ins, make_op("slice", slc), {scale});
auto zp_slc = m.insert_instruction(ins, make_op("slice", slc), {zp});
return m.insert_instruction(ins, ins->get_operator(), {i, scale_slc, zp_slc});
});

m.replace_instruction(ins, cat_ins->get_operator(), new_cat_inputs);
}
};

bool is_same_value(instruction_ref a, instruction_ref b)
{
if(a == b)
Expand Down Expand Up @@ -456,6 +506,8 @@
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
}

Expand Down
101 changes: 100 additions & 1 deletion test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -1539,4 +1539,103 @@
EXPECT(migraphx::contains(res_1.instructions, "q"));
}

// quantizelinear with a scalar (per-tensor) scale applied to
// concat(i1, relu(i2)): after the pass, each concat input is expected to be
// quantized separately with sliced scale / zero point, and the concat is
// expected to come last.
TEST_CASE(pointwise_concat_quant_per_tensor)
{
    migraphx::shape s1{migraphx::shape::float_type, {1, 4, 28, 28}};
    migraphx::shape s2{migraphx::shape::float_type, {1, 2, 28, 28}};

    // Input graph: relu -> concat -> quantizelinear.
    migraphx::module m1;
    {
        auto i1    = m1.add_parameter("i1", s1);
        auto i2    = m1.add_parameter("i2", s2);
        auto scale = m1.add_literal(0.5f);
        auto zero  = m1.add_literal(std::int8_t{0});

        auto relu = m1.add_instruction(migraphx::make_op("relu"), i2);
        auto cat  = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), i1, relu);
        auto q    = add_quantize_op(m1, "quantizelinear", cat, scale, zero);
        m1.add_return({q});
    }

    // Expected graph: quantizelinear per concat input, then concat.
    migraphx::module m2;
    {
        // Dimensions of the concat output; only needed to build the expected graph.
        std::vector<std::size_t> cat_lens{1, 6, 28, 28};

        auto i1    = m2.add_parameter("i1", s1);
        auto i2    = m2.add_parameter("i2", s2);
        auto scale = m2.add_literal(0.5f);
        auto zero  = m2.add_literal(std::int8_t{0});

        auto relu     = m2.add_instruction(migraphx::make_op("relu"), i2);
        auto scale_mb = broadcast_scale(m2, scale, cat_lens, 1);
        auto zero_mb  = broadcast_shift(m2, zero, cat_lens);

        // Channels [0, 4) belong to i1.
        auto sc1 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), scale_mb);
        auto zp1 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), zero_mb);
        auto q1 = add_quantize_op(m2, "quantizelinear", i1, sc1, zp1);

        // Channels [4, 6) belong to relu(i2).
        auto sc2 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), scale_mb);
        auto zp2 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), zero_mb);
        auto q2 = add_quantize_op(m2, "quantizelinear", relu, sc2, zp2);

        auto cat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), q1, q2);
        m2.add_return({cat});
    }
    run_pass(m1);
    EXPECT(m1 == m2);
}

// quantizelinear with a per-channel scale (one value per concat-axis element,
// shape {6}) applied to concat(i1, relu(i2)): after the pass, each concat
// input is expected to be quantized separately with the matching slice of the
// scale / zero point, and the concat is expected to come last.
TEST_CASE(pointwise_concat_quant_per_channel)
{
    migraphx::shape s1{migraphx::shape::float_type, {1, 4, 28, 28}};
    migraphx::shape s2{migraphx::shape::float_type, {1, 2, 28, 28}};
    migraphx::shape s3{migraphx::shape::float_type, {6}};

    // Input graph: relu -> concat -> quantizelinear.
    migraphx::module m1;
    {
        auto i1    = m1.add_parameter("i1", s1);
        auto i2    = m1.add_parameter("i2", s2);
        auto scale = m1.add_literal(migraphx::generate_literal(s3, 0));
        auto zero  = m1.add_literal(std::int8_t{0});

        auto relu = m1.add_instruction(migraphx::make_op("relu"), i2);
        auto cat  = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), i1, relu);
        auto q    = add_quantize_op(m1, "quantizelinear", cat, scale, zero);
        m1.add_return({q});
    }

    // Expected graph: quantizelinear per concat input, then concat.
    migraphx::module m2;
    {
        // Dimensions of the concat output; only needed to build the expected graph.
        std::vector<std::size_t> cat_lens{1, 6, 28, 28};

        auto i1    = m2.add_parameter("i1", s1);
        auto i2    = m2.add_parameter("i2", s2);
        auto scale = m2.add_literal(migraphx::generate_literal(s3, 0));
        auto zero  = m2.add_literal(std::int8_t{0});

        auto relu     = m2.add_instruction(migraphx::make_op("relu"), i2);
        auto scale_mb = broadcast_scale(m2, scale, cat_lens, 1);
        auto zero_mb  = broadcast_shift(m2, zero, cat_lens);

        // Channels [0, 4) belong to i1.
        auto sc1 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), scale_mb);
        auto zp1 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), zero_mb);
        auto q1 = add_quantize_op(m2, "quantizelinear", i1, sc1, zp1);

        // Channels [4, 6) belong to relu(i2).
        auto sc2 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), scale_mb);
        auto zp2 = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), zero_mb);
        auto q2 = add_quantize_op(m2, "quantizelinear", relu, sc2, zp2);

        auto cat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), q1, q2);
        m2.add_return({cat});
    }
    run_pass(m1);
    EXPECT(m1 == m2);
}

// Entry point of the test binary: forwards the command-line arguments to the
// test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
Loading