From c5c9780ec0e008f8037349478d6e3bf926643146 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Tue, 29 Oct 2024 16:57:44 -0400 Subject: [PATCH] Revert "Fixed instruction::replace() logic. (#3553)" This reverts commit 58cf5997dfbc44a70b529afe4a12549365000f99. --- src/include/migraphx/output_iterator.hpp | 6 ----- src/instruction.cpp | 28 +++++------------------- test/instruction.cpp | 20 ----------------- 3 files changed, 5 insertions(+), 49 deletions(-) diff --git a/src/include/migraphx/output_iterator.hpp b/src/include/migraphx/output_iterator.hpp index e4d670b8537..7aced4a08a3 100644 --- a/src/include/migraphx/output_iterator.hpp +++ b/src/include/migraphx/output_iterator.hpp @@ -72,12 +72,6 @@ auto join_back_inserter(Container& c) [&](const auto& r) { c.insert(c.end(), r.begin(), r.end()); }); } -template -auto push_inserter(Container& c) -{ - return make_function_output_iterator([&](const auto& x) { c.push(x); }); -} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx - #endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP diff --git a/src/instruction.cpp b/src/instruction.cpp index 219c75e3432..47bea70379e 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -26,8 +26,7 @@ #include #include #include -#include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -59,39 +58,22 @@ instruction::instruction(literal l) { } -struct replace_shape_order -{ - instruction_ref start; - - std::size_t location(instruction_ref x) const { return std::distance(start, x); } - - bool operator()(instruction_ref x, instruction_ref y) const - { - return location(x) > location(y); - } -}; - void instruction::replace(const shape& r) { if(r != result) { result = r; - auto start = std::find_if(output.front()->inputs().begin(), - output.front()->inputs().end(), - [&](instruction_ref x) { return this == as_address(x); }); - assert(as_address(*start) == this); - std::priority_queue, replace_shape_order> q( - output.begin(), output.end(), replace_shape_order{*start}); + std::deque q(output.begin(), output.end()); while(not q.empty()) { - instruction_ref ins = q.top(); - q.pop(); + instruction_ref ins = q.front(); + q.pop_front(); assert(ins->name() == "@return" or ins->name().front() != '@'); shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args); if(new_r != ins->result) { ins->result = new_r; - std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q)); + std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q)); } } } diff --git a/test/instruction.cpp b/test/instruction.cpp index 0ee22e13553..134658e336b 100644 --- a/test/instruction.cpp +++ b/test/instruction.cpp @@ -67,24 +67,4 @@ TEST_CASE(check_replace_shape) EXPECT(add->get_shape() == r); } -TEST_CASE(check_replace_dag) -{ - migraphx::module m; - migraphx::shape s{migraphx::shape::float_type, {3, 2}}; - auto input = m.add_parameter("x", s); - auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input); - auto abs = m.add_instruction(migraphx::make_op("abs"), reduce); - auto sin = m.add_instruction(migraphx::make_op("sin"), reduce); - auto add = m.add_instruction(migraphx::make_op("add"), abs, sin); - auto add2 = m.add_instruction(migraphx::make_op("add"), add, reduce); - - reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}})); - - migraphx::shape r{migraphx::shape::float_type, {3, 1}}; - EXPECT(reduce->get_shape() == r); - EXPECT(abs->get_shape() == r); - EXPECT(sin->get_shape() == r); - EXPECT(add->get_shape() == r); - EXPECT(add2->get_shape() == r); -} int main(int argc, const char* argv[]) { test::run(argc, argv); }