diff --git a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp index 51b43197587d5c..2d447ac41eb081 100644 --- a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp @@ -11,6 +11,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API EliminateConcat; +class TRANSFORMATIONS_API EliminateConcatSplit; class TRANSFORMATIONS_API EliminateConvert; class TRANSFORMATIONS_API EliminateConvertNonZero; class TRANSFORMATIONS_API EliminateEltwise; @@ -83,6 +84,16 @@ class ov::pass::EliminateConcat : public ov::pass::MatcherPass { EliminateConcat(); }; +/** + * @ingroup ov_transformation_common_api + * @brief EliminateConcatSplit eliminates split from concat that no need + */ +class ov::pass::EliminateConcatSplit : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("EliminateConcatSplit"); + EliminateConcatSplit(); +}; + /** * @ingroup ov_transformation_common_api * @brief EliminateSplit eliminates split that does nothing diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 093275a5cb9094..254246c4f8240f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "compare.hpp" #include "itt.hpp" @@ -512,6 +513,174 @@ pass::EliminateConcat::EliminateConcat() { this->register_matcher(m, callback); } +pass::EliminateConcatSplit::EliminateConcatSplit() { + MATCHER_SCOPE(EliminateConcatSplit); + auto pattern_concat = pattern::wrap_type(); + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_map(); + const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); + if (!concat || concat->is_dynamic() || concat->get_users().size() == 1) { + return false; + } + const auto concat_users = concat->get_users(); + auto concat_inputs = concat->inputs(); + auto concat_axis = concat->get_axis(); + if (concat_axis < 0) + concat_axis = concat_axis + concat->get_output_shape(0).size(); + + bool need_opt = false; + std::vector, int64_t, int64_t>> slice_out_index_in_cocat; + for (const auto& user : concat_users) { + if (ov::is_type(user)) { + auto strided_slice_node = ov::as_type_ptr(user); + if (!strided_slice_node) { + return false; + } + // check that all values of the mask is equal 0 + auto check_mask = [](const std::vector& mask_to_check) { + auto it = std::find_if(mask_to_check.begin(), mask_to_check.end(), [](const int64_t& value) { + return value != 0; + }); + if (mask_to_check.empty() || it == mask_to_check.end()) { + return true; + } + return false; + }; + // check that we won't do change dimention rank + if (!check_mask(strided_slice_node->get_shrink_axis_mask()) || + !check_mask(strided_slice_node->get_new_axis_mask()) || + !check_mask(strided_slice_node->get_ellipsis_mask())) { + return false; + } + + auto begin_node = strided_slice_node->get_input_node_shared_ptr(1); + const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); + auto begin_values = begin_constant_node->cast_vector(); + + auto end_node = strided_slice_node->get_input_node_shared_ptr(2); + const auto& end_constant_node = ov::util::get_constant_from_source(end_node); + auto end_values = end_constant_node->cast_vector(); + + need_opt = true; + slice_out_index_in_cocat.push_back( + std::make_tuple(strided_slice_node, begin_values[concat_axis], end_values[concat_axis] - 1)); + } else { + need_opt = false; + break; + } + } + if (!need_opt || (slice_out_index_in_cocat.size() == 1)) + return false; + + uint64_t start_index = 0; + std::vector, int64_t, int64_t>> in_index_in_concat; + for (auto& concat_in : concat_inputs) { + auto tmp_index = start_index + concat_in.get_shape()[concat_axis] - 1; + in_index_in_concat.push_back( + std::make_tuple(concat_in.get_source_output().get_node_shared_ptr(), start_index, tmp_index)); + start_index = tmp_index + 1; + } + + std::vector, int64_t, int64_t>> mismatch_slices{}; + for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_cocat) { + bool matched = false; + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if (slice_begin == concat_input_begin && slice_end == concat_input_end) { + auto slice_outputs = slice_node->outputs(); + for (auto& slice_output : slice_outputs) { + replace_output_update_name(slice_output, concat_input_node); + } + matched = true; + break; + } + } + if (!matched) + mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end)); + } + + int64_t new_start_value{std::numeric_limits::max()}; + int64_t new_end_value{0}; + for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if ((concat_input_begin <= slice_begin) && (concat_input_end > slice_begin)) { + if (concat_input_begin < new_start_value) + new_start_value = concat_input_begin; + if (concat_input_end > new_end_value) + new_end_value = concat_input_end; + } + if ((concat_input_begin < slice_end) && (concat_input_end >= slice_end)) { + if (concat_input_begin < new_start_value) + new_start_value = concat_input_begin; + if (concat_input_end > new_end_value) + new_end_value = concat_input_end; + } + } + } + + std::vector> new_concat_in_nodes{}; + bool new_need = false; + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if (concat_input_begin == new_start_value) { + new_need = true; + } + if (concat_input_end == new_end_value) { + new_concat_in_nodes.push_back(concat_input_node); + new_need = false; + } + if (new_need) { + new_concat_in_nodes.push_back(concat_input_node); + } + } + + auto new_concat_node = concat->clone_with_new_inputs(ov::as_output_vector(new_concat_in_nodes)); + replace_output_update_name(concat, new_concat_node); + + for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { + if (slice_node->get_users().size() == 1 && ov::is_type(slice_node->get_users()[0]) && + ov::as_type_ptr(slice_node->get_users()[0])->get_axis() == concat_axis) { + auto next_concat = ov::as_type_ptr(slice_node->get_users()[0]); + auto next_concat_inputs = next_concat->input_values(); + std::vector> new_next_concat_inputs{}; + for (const auto& t : next_concat_inputs) { + if (t.get_node_shared_ptr() == slice_node) { + for (const auto& need_insert : new_concat_in_nodes) { + new_next_concat_inputs.push_back(need_insert); + } + continue; + } + new_next_concat_inputs.push_back(t.get_node_shared_ptr()); + } + auto new_next_concat_node = + next_concat->clone_with_new_inputs(ov::as_output_vector(new_next_concat_inputs)); + replace_output_update_name(next_concat, new_next_concat_node); + } else { + std::vector> new_slice_in_nodes{}; + new_slice_in_nodes.push_back(new_concat_node); + + auto begin_node = slice_node->get_input_node_shared_ptr(1); + const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); + auto begin_values = begin_constant_node->cast_vector(); + begin_values[concat_axis] = slice_begin - new_start_value; + new_slice_in_nodes.push_back( + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{begin_values.size()}, begin_values)); + + auto end_node = slice_node->get_input_node_shared_ptr(2); + const auto& end_constant_node = ov::util::get_constant_from_source(end_node); + auto end_values = end_constant_node->cast_vector(); + end_values[concat_axis] = slice_end - new_start_value + 1; + new_slice_in_nodes.push_back( + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{end_values.size()}, end_values)); + auto new_slice_node = slice_node->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); + replace_output_update_name(slice_node, new_slice_node); + } + } + return true; + }; + + auto m = make_shared(pattern_concat, matcher_name); + this->register_matcher(m, callback); +} + pass::EliminateSplit::EliminateSplit() { MATCHER_SCOPE(EliminateSplit); auto convert_pattern = pattern::wrap_type(); @@ -1160,6 +1329,7 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { ADD_MATCHER_FOR_THIS(EliminateSplitConcat) ADD_MATCHER_FOR_THIS(EliminateStridedSlice) ADD_MATCHER_FOR_THIS(EliminateSlice) + ADD_MATCHER_FOR_THIS(EliminateConcatSplit) // shape-dependent transformations if (use_shape_for_elimination) { diff --git a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp index 7da5c79981d415..f1fe20d85eabaf 100644 --- a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp @@ -1537,6 +1537,110 @@ TEST_F(TransformationTestsF, EliminateSliceBeforeGatherElements) { } } +TEST_F(TransformationTestsF, EliminateConcatSplit) { + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto param4 = make_shared(element::f32, Shape{2, 10, 6}); + auto concat = make_shared(ov::as_output_vector({param1, param2, param3, param4}), axis); + + auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto strided_slice1 = std::make_shared(concat, + begin_const1, + end_const1, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + + auto add_param = make_shared(element::f32, Shape{2, 10, 3}); + auto add = std::make_shared(strided_slice1, add_param); + auto result1 = std::make_shared(add); + + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 18}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice2); + auto result2 = std::make_shared(relu); + + model = std::make_shared(ResultVector{result1, result2}, + ParameterVector{param1, param2, param3, param4, add_param}); + manager.register_pass(); + } + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto param4 = make_shared(element::f32, Shape{2, 10, 6}); + auto add_param = make_shared(element::f32, Shape{2, 10, 3}); + auto add = std::make_shared(param1, add_param); + auto result1 = std::make_shared(add); + + auto concat = make_shared(ov::as_output_vector({param2, param3, param4}), axis); + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 15}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice2); + auto result2 = std::make_shared(relu); + + model_ref = std::make_shared(ResultVector{result1, result2}, + ParameterVector{param1, param2, param3, param4, add_param}); + } +} + +TEST_F(TransformationTestsF, EliminateConcatSplitConcat) { + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto concat = make_shared(ov::as_output_vector({param1, param2, param3}), axis); + + auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto strided_slice1 = std::make_shared(concat, + begin_const1, + end_const1, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice1); + + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 12}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto concat1 = make_shared(ov::as_output_vector({relu, strided_slice2}), axis); + + auto result = std::make_shared(concat1); + model = std::make_shared(ResultVector{result}, ParameterVector{param1, param2, param3}); + manager.register_pass(); + } + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + + auto relu = std::make_shared(param1); + auto concat = make_shared(ov::as_output_vector({relu, param2, param3}), axis); + auto result = std::make_shared(concat); + + model_ref = std::make_shared(ResultVector{result}, ParameterVector{param1, param2, param3}); + } +} TEST_F(TransformationTestsF, EliminateStridedSlice) { { auto input = std::make_shared(ov::element::f32, diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index 185d3380f42859..09d773c462dd03 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -172,7 +172,19 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector& mask) std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs); - check_new_args_count(this, new_args); + auto args_size = new_args.size(); + NODE_VALIDATION_CHECK(this, (args_size == 3) || (args_size == 4), "Incorrect number of new inputs: ", args_size); + + if (args_size == 3) { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + m_begin_mask, + m_end_mask, + m_new_axis_mask, + m_shrink_axis_mask, + m_ellipsis_mask); + } return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2),