Skip to content

Commit

Permalink
mul add transpose dot matcher (#2809)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Jun 14, 2024
1 parent eb0008a commit a05069e
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,36 @@ struct find_reshape_dot
}
};

// Remove transposes and converts between mul/add -> dot so simplify_algebra can perform
// const folding simplifications
struct find_mul_add_shape_op_dot
{
auto matcher() const
{
auto shape_ops = match::name("transpose", "convert");
auto const_mul_add = match::name("mul", "add")(match::either_arg(0, 1)(
match::is_constant().bind("const"), match::any().bind("input")));
auto match_shape_op = shape_ops(match::args(const_mul_add.bind("pw")));
auto skip_shape_op_outputs = match::skip_output(match::any_of(shape_ops));
return match_shape_op(skip_shape_op_outputs(match::name("dot")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto shape_ins = r.result;
auto pw = r.instructions["pw"];
auto constant = r.instructions["const"];
auto input = r.instructions["input"];

auto shape_op = shape_ins->get_operator();
auto pw_op = pw->get_operator();
auto new_inp = m.insert_instruction(shape_ins, shape_op, input);
auto new_const = m.insert_instruction(shape_ins, shape_op, constant);

m.replace_instruction(shape_ins, pw_op, new_inp, new_const);
}
};

void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < depth; i++)
Expand All @@ -1142,6 +1172,7 @@ void simplify_reshapes::apply(module& m) const
find_slice_transpose{},
find_unary_shape_transforms{},
find_reshape_dot{},
find_mul_add_shape_op_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
}
Expand Down
92 changes: 92 additions & 0 deletions test/optimize_module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@
*/

#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>

Expand Down Expand Up @@ -98,4 +102,92 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
EXPECT(m1 == m2);
}

TEST_CASE(mul_add_transpose_dot)
{
auto lit1 = migraphx::generate_literal({migraphx::shape::float_type, {64}}, 0);
auto lit2 = migraphx::generate_literal({migraphx::shape::float_type, {64}}, 1);
auto lit3 = migraphx::generate_literal({migraphx::shape::float_type, {64, 64}}, 2);
migraphx::module m1;
{
auto in1 = m1.add_parameter("x", {migraphx::shape::float_type, {2, 64, 4, 4}});
auto lit1_ins = m1.add_literal(lit1);
auto lit1_unsq =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2, 3}}}), lit1_ins);
auto lit1_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 4, 4}}}), lit1_unsq);
auto mul = m1.add_instruction(migraphx::make_op("mul"), lit1_mb, in1);

auto lit2_ins = m1.add_literal(lit2);
auto lit2_unsq =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2, 3}}}), lit2_ins);
auto lit2_tp = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), lit2_unsq);
auto lit2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 4, 64}}}), lit2_tp);

auto mul_tp = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), mul);
auto add = m1.add_instruction(migraphx::make_op("add"), mul_tp, lit2_mb);

auto lit3_ins = m1.add_literal(lit3);
auto lit3_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 64, 64}}}), lit3_ins);
auto dot = m1.add_instruction(migraphx::make_op("dot"), add, lit3_mb);

m1.add_return({dot});
}
run_pass(m1);

// Compute const propagated literals
migraphx::literal lit13;
migraphx::literal lit23;
migraphx::module lit_mod;
{
auto lit1_ins = lit_mod.add_literal(lit1);
auto lit1_unsq = lit_mod.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", {0, 2, 3}}}), lit1_ins);
auto lit1_tp = lit_mod.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), lit1_unsq);
auto lit1_mb = lit_mod.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 64, 64}}}), lit1_tp);

auto lit3_ins = lit_mod.add_literal(lit3);
auto lit3_mb = lit_mod.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 64, 64}}}), lit3_ins);

auto mul_lit = lit_mod.add_instruction(migraphx::make_op("mul"), lit1_mb, lit3_mb);
auto lit13_arg = mul_lit->eval();
lit13 = migraphx::literal(lit13_arg.get_shape(), lit13_arg.data());

auto lit2_ins = lit_mod.add_literal(lit2);
auto lit2_unsq =
lit_mod.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), lit2_ins);
auto lit2_mb = lit_mod.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 64}}}), lit2_unsq);

auto dot_lit = lit_mod.add_instruction(migraphx::make_op("dot"), lit2_mb, lit3_ins);
auto lit23_arg = dot_lit->eval();
lit23 = migraphx::literal(lit23_arg.get_shape(), lit23_arg.data());
}

migraphx::module m2;
{
auto in1 = m2.add_parameter("x", {migraphx::shape::float_type, {2, 64, 4, 4}});
auto in_tp = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), in1);

auto lit13_ins = m2.add_literal(lit13);
auto dot = m2.add_instruction(migraphx::make_op("dot"), in_tp, lit13_ins);

auto lit23_ins = m2.add_literal(lit23);
auto lit23_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 4, 64}}}), lit23_ins);

auto add = m2.add_instruction(migraphx::make_op("add"), dot, lit23_mb);
m2.add_return({add});
}

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
70 changes: 70 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2432,4 +2432,74 @@ TEST_CASE(reshape_dot_broadcast_2)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(mul_transpose)
{
migraphx::shape s{migraphx::shape::float_type, {2, 32, 64, 64}};
migraphx::shape s2{migraphx::shape::float_type, {2, 64, 32, 32}};
migraphx::module m1;
{
auto inp = m1.add_parameter("input", s);
auto c1 = m1.add_literal(migraphx::generate_literal(s));
auto mul = m1.add_instruction(migraphx::make_op("mul"), inp, c1);
auto trans = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), mul);

auto c3 = m1.add_literal(migraphx::generate_literal(s2));
auto dot = m1.add_instruction(migraphx::make_op("dot"), trans, c3);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("input", s);
auto inp_trans = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), inp);
auto c1 = m2.add_literal(migraphx::generate_literal(s));
auto c1_trans =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), c1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), inp_trans, c1_trans);
auto c3 = m2.add_literal(migraphx::generate_literal(s2));
auto dot = m2.add_instruction(migraphx::make_op("dot"), mul, c3);
m2.add_return({dot});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(add_transpose)
{
migraphx::shape s{migraphx::shape::float_type, {2, 32, 64, 64}};
migraphx::shape s2{migraphx::shape::float_type, {2, 64, 32, 32}};
migraphx::module m1;
{
auto inp = m1.add_parameter("input", s);
auto c1 = m1.add_literal(migraphx::generate_literal(s));
auto mul = m1.add_instruction(migraphx::make_op("add"), inp, c1);
auto trans = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), mul);

auto c3 = m1.add_literal(migraphx::generate_literal(s2));
auto dot = m1.add_instruction(migraphx::make_op("dot"), trans, c3);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("input", s);
auto inp_trans = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), inp);
auto c1 = m2.add_literal(migraphx::generate_literal(s));
auto c1_trans =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), c1);
auto mul = m2.add_instruction(migraphx::make_op("add"), inp_trans, c1_trans);
auto c3 = m2.add_literal(migraphx::generate_literal(s2));
auto dot = m2.add_instruction(migraphx::make_op("dot"), mul, c3);
m2.add_return({dot});
};

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit a05069e

Please sign in to comment.