Skip to content

Commit

Permalink
use reshape to handle flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Nov 1, 2024
1 parent 45b59bc commit 237916e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
12 changes: 8 additions & 4 deletions src/include/migraphx/op/flatten.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ struct flatten
}
else
{
check_shapes{inputs, *this}.standard();
auto&& lens = s.lens();
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
Expand All @@ -91,9 +90,14 @@ struct flatten
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};

visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
};

} // namespace op
Expand Down
14 changes: 14 additions & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,13 +1107,27 @@ struct find_mul_add_shape_op_dot
}
};

struct find_flatten
{
auto matcher() const { return match::name("flatten"); }

void apply(module& m, const match::matcher_result& r) const
{
auto flatten = r.result;
m.replace_instruction(flatten,
make_op("reshape", {{"dims", flatten->get_shape().lens()}}),
flatten->inputs());
}
};

void simplify_reshapes::apply(module& m) const
{
m.repeat_while_changes(depth, [&] {
match::find_matches(m,
find_where_op{},
find_resize{},
find_nop_reshapes{},
find_flatten{},
find_reshape_cont{},
find_nested_shape_transforms{},
find_concat_slice{},
Expand Down
2 changes: 1 addition & 1 deletion test/eliminate_contiguous_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ TEST_CASE(non_standard_flatten_op)
m.add_instruction(migraphx::make_op("flatten"), c);
auto count = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(std::distance(m.begin(), m.end()) == count);
EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
}

TEST_CASE(standard_flatten_op)
Expand Down
22 changes: 22 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2630,4 +2630,26 @@ TEST_CASE(add_transpose)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(flatten)
{
migraphx::shape s{migraphx::shape::float_type, {4608, 8, 2}};

migraphx::module m1;
{
auto inp = m1.add_parameter("input", s);
auto flat = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), inp);
m1.add_return({flat});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("input", s);
auto flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4608, 16}}}), inp);
m2.add_return({flat});
};

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 237916e

Please sign in to comment.