diff --git a/src/eliminate_contiguous.cpp b/src/eliminate_contiguous.cpp index cf9e0a6e344..217103e5489 100644 --- a/src/eliminate_contiguous.cpp +++ b/src/eliminate_contiguous.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace migraphx { @@ -71,6 +72,8 @@ static bool try_compute_shape(instruction_ref ins, void eliminate_contiguous::apply(module& m) const { + std::vector const_instruction; + for(auto ins : iterator_for(m)) { // return instruction should have inputs with standard shape @@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& m) const auto args = ins->inputs(); auto new_args = args; auto mod_args = ins->module_inputs(); + for(auto arg : ins->inputs()) { if(arg->name() == op_name) @@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& m) const } else if(prev->can_eval()) { - auto c = op::contiguous{}; - auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); - - auto l = m.add_literal(r.get_shape(), r.data()); - m.replace_instruction(arg, l); + const_instruction.push_back(arg); } } } } + + // Perform evaluations in parallel + std::vector literals(const_instruction.size()); + par_for(const_instruction.size(), 1, [&](const auto i) { + auto c = op::contiguous{}; + auto prev = const_instruction[i]->inputs().front(); + literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); + }); + + for(size_t i = 0; i < const_instruction.size(); i++) + { + auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); + m.replace_instruction(const_instruction[i], l); + } } } // namespace MIGRAPHX_INLINE_NS