diff --git a/src/include/migraphx/output_iterator.hpp b/src/include/migraphx/output_iterator.hpp index 7aced4a08a3..e4d670b8537 100644 --- a/src/include/migraphx/output_iterator.hpp +++ b/src/include/migraphx/output_iterator.hpp @@ -72,6 +72,12 @@ auto join_back_inserter(Container& c) [&](const auto& r) { c.insert(c.end(), r.begin(), r.end()); }); } +template +auto push_inserter(Container& c) +{ + return make_function_output_iterator([&](const auto& x) { c.push(x); }); +} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx + #endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP diff --git a/src/instruction.cpp b/src/instruction.cpp index 47bea70379e..219c75e3432 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -58,22 +59,39 @@ instruction::instruction(literal l) { } +struct replace_shape_order +{ + instruction_ref start; + + std::size_t location(instruction_ref x) const { return std::distance(start, x); } + + bool operator()(instruction_ref x, instruction_ref y) const + { + return location(x) > location(y); + } +}; + void instruction::replace(const shape& r) { if(r != result) { result = r; - std::deque q(output.begin(), output.end()); + auto start = std::find_if(output.front()->inputs().begin(), + output.front()->inputs().end(), + [&](instruction_ref x) { return this == as_address(x); }); + assert(as_address(*start) == this); + std::priority_queue, replace_shape_order> q( + output.begin(), output.end(), replace_shape_order{*start}); while(not q.empty()) { - instruction_ref ins = q.front(); - q.pop_front(); + instruction_ref ins = q.top(); + q.pop(); assert(ins->name() == "@return" or ins->name().front() != '@'); shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args); if(new_r != ins->result) { ins->result = new_r; - std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q)); + std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q)); } } } diff --git a/test/instruction.cpp b/test/instruction.cpp index 134658e336b..0ee22e13553 100644 --- a/test/instruction.cpp +++ b/test/instruction.cpp @@ -67,4 +67,24 @@ TEST_CASE(check_replace_shape) EXPECT(add->get_shape() == r); } +TEST_CASE(check_replace_dag) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {3, 2}}; + auto input = m.add_parameter("x", s); + auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input); + auto abs = m.add_instruction(migraphx::make_op("abs"), reduce); + auto sin = m.add_instruction(migraphx::make_op("sin"), reduce); + auto add = m.add_instruction(migraphx::make_op("add"), abs, sin); + auto add2 = m.add_instruction(migraphx::make_op("add"), add, reduce); + + reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}})); + + migraphx::shape r{migraphx::shape::float_type, {3, 1}}; + EXPECT(reduce->get_shape() == r); + EXPECT(abs->get_shape() == r); + EXPECT(sin->get_shape() == r); + EXPECT(add->get_shape() == r); + EXPECT(add2->get_shape() == r); +} int main(int argc, const char* argv[]) { test::run(argc, argv); }