From 236a2e54a597725b6c92cb68cb3848c771945dd9 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Wed, 5 Feb 2025 14:32:21 +0800 Subject: [PATCH 1/9] [Common][Transformations] eliminated no needed split & concat Signed-off-by: Zhai, Xuejun --- .../common_optimizations/nop_elimination.hpp | 11 ++ .../common_optimizations/nop_elimination.cpp | 172 ++++++++++++++++++ .../common_optimizations/nop_elimination.cpp | 104 +++++++++++ src/core/src/op/strided_slice.cpp | 10 + 4 files changed, 297 insertions(+) diff --git a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp index 51b43197587d5c..2d447ac41eb081 100644 --- a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp @@ -11,6 +11,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API EliminateConcat; +class TRANSFORMATIONS_API EliminateConcatSplit; class TRANSFORMATIONS_API EliminateConvert; class TRANSFORMATIONS_API EliminateConvertNonZero; class TRANSFORMATIONS_API EliminateEltwise; @@ -83,6 +84,16 @@ class ov::pass::EliminateConcat : public ov::pass::MatcherPass { EliminateConcat(); }; +/** + * @ingroup ov_transformation_common_api + * @brief EliminateConcatSplit eliminates split from concat that no need + */ +class ov::pass::EliminateConcatSplit : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("EliminateConcatSplit"); + EliminateConcatSplit(); +}; + /** * @ingroup ov_transformation_common_api * @brief EliminateSplit eliminates split that does nothing diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 093275a5cb9094..a52c78d17c5cfa 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -512,6 +512,177 @@ pass::EliminateConcat::EliminateConcat() { this->register_matcher(m, callback); } +pass::EliminateConcatSplit::EliminateConcatSplit() { + MATCHER_SCOPE(EliminateConcatSplit); + auto pattern_concat = pattern::wrap_type(); + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_map(); + const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); + if (!concat) { + return false; + } + const auto concat_users = concat->get_users(); + auto concat_inputs = concat->inputs(); + auto concat_axis = concat->get_axis(); + if (concat_axis < 0) + concat_axis = concat_axis + concat->get_output_shape(0).size(); + + bool need_opt = false; + std::vector, u_int64_t, u_int64_t>> slice_out_index_in_cocat; + for (const auto& user : concat_users) { + if (ov::is_type(user)) { + auto strided_slice_node = ov::as_type_ptr(user); + if (!strided_slice_node) { + return false; + } + // check that all values of the mask is equal 0 + auto check_mask = [](const std::vector& mask_to_check) { + auto it = std::find_if(mask_to_check.begin(), mask_to_check.end(), [](const int64_t& value) { + return value != 0; + }); + if (mask_to_check.empty() || it == mask_to_check.end()) { + return true; + } + return false; + }; + // check that we won't do change dimention rank + if (!check_mask(strided_slice_node->get_shrink_axis_mask()) || + !check_mask(strided_slice_node->get_new_axis_mask()) || + !check_mask(strided_slice_node->get_ellipsis_mask())) { + return false; + } + + auto begin_node = strided_slice_node->get_input_node_shared_ptr(1); + const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); + auto begin_values = begin_constant_node->cast_vector(); + auto begin_mask = strided_slice_node->get_begin_mask(); + + auto end_node = strided_slice_node->get_input_node_shared_ptr(2); + const auto& end_constant_node = ov::util::get_constant_from_source(end_node); + auto end_values = end_constant_node->cast_vector(); + auto end_mask = strided_slice_node->get_end_mask(); + + need_opt = true; + slice_out_index_in_cocat.push_back( + std::make_tuple(strided_slice_node, begin_values[concat_axis], end_values[concat_axis] - 1)); + } else { + need_opt = false; + } + } + if (!need_opt) + return false; + + uint64_t start_index = 0; + std::vector, u_int64_t, u_int64_t>> in_index_in_cocat; + for (auto& concat_in : concat_inputs) { + auto tmp_index = start_index + concat_in.get_shape()[concat_axis] - 1; + in_index_in_cocat.push_back( + std::make_tuple(concat_in.get_source_output().get_node_shared_ptr(), start_index, tmp_index)); + start_index = tmp_index + 1; + } + + std::vector, u_int64_t, u_int64_t>> mismatch_slices{}; + for (const auto& slice_out : slice_out_index_in_cocat) { + bool matched = false; + for (const auto& concat_in : in_index_in_cocat) { + if (get<1>(slice_out) == get<1>(concat_in) && get<2>(slice_out) == get<2>(concat_in)) { + auto slice_outputs = get<0>(slice_out)->outputs(); + for (auto& slice_output : slice_outputs) { + replace_output_update_name(slice_output, get<0>(concat_in)); + } + matched = true; + break; + } + } + if (!matched) + mismatch_slices.push_back(slice_out); + } + + u_int64_t new_start_value{std::numeric_limits::max()}; + u_int64_t new_end_value{0}; + for (const auto& mismatch_slice : mismatch_slices) { + for (const auto& concat_in : in_index_in_cocat) { + if ((get<1>(concat_in) <= get<1>(mismatch_slice)) && (get<2>(concat_in) > get<1>(mismatch_slice))) { + if (get<1>(concat_in) < new_start_value) + new_start_value = get<1>(concat_in); + if (get<2>(concat_in) > new_end_value) + new_end_value = get<2>(concat_in); + } + if ((get<1>(concat_in) < get<2>(mismatch_slice)) && (get<2>(concat_in) >= get<2>(mismatch_slice))) { + if (get<1>(concat_in) < new_start_value) + new_start_value = get<1>(concat_in); + if (get<2>(concat_in) > new_end_value) + new_end_value = get<2>(concat_in); + } + } + } + + std::vector> new_concat_in_nodes{}; + bool new_need = false; + for (const auto& concat_in : in_index_in_cocat) { + if (get<1>(concat_in) == new_start_value) { + new_need = true; + } + if (get<2>(concat_in) == new_end_value) { + new_concat_in_nodes.push_back(get<0>(concat_in)); + new_need = false; + } + if (new_need) { + new_concat_in_nodes.push_back(get<0>(concat_in)); + } + } + + auto new_concat_node = concat->clone_with_new_inputs(ov::as_output_vector(new_concat_in_nodes)); + replace_output_update_name(concat, new_concat_node); + + for (const auto& mismatch_slice : mismatch_slices) { + auto& slice_node = get<0>(mismatch_slice); + if (slice_node->get_users().size() == 1 && ov::is_type(slice_node->get_users()[0]) && + ov::as_type_ptr(slice_node->get_users()[0])->get_axis() == concat_axis) { + auto next_concat = ov::as_type_ptr(slice_node->get_users()[0]); + auto next_concat_inputs = next_concat->input_values(); + std::vector> new_next_concat_inputs{}; + for (const auto& t : next_concat_inputs) { + if (t.get_node_shared_ptr() == slice_node) { + for (const auto& need_insert : new_concat_in_nodes) { + new_next_concat_inputs.push_back(need_insert); + } + continue; + } + new_next_concat_inputs.push_back(t.get_node_shared_ptr()); + } + auto new_next_concat_node = + next_concat->clone_with_new_inputs(ov::as_output_vector(new_next_concat_inputs)); + replace_output_update_name(next_concat, new_next_concat_node); + } else { + std::vector> new_slice_in_nodes{}; + new_slice_in_nodes.push_back(new_concat_node); + + auto begin_node = slice_node->get_input_node_shared_ptr(1); + const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); + auto begin_values = begin_constant_node->cast_vector(); + begin_values[concat_axis] = get<1>(mismatch_slice) - new_start_value; + new_slice_in_nodes.push_back( + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{begin_values.size()}, begin_values)); + + auto end_node = slice_node->get_input_node_shared_ptr(2); + const auto& end_constant_node = ov::util::get_constant_from_source(end_node); + auto end_values = end_constant_node->cast_vector(); + end_values[concat_axis] = get<2>(mismatch_slice) - new_start_value + 1; + new_slice_in_nodes.push_back( + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{end_values.size()}, end_values)); + auto new_slice_node = + get<0>(mismatch_slice)->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); + replace_output_update_name(get<0>(mismatch_slice), new_slice_node); + } + } + return true; + }; + + auto m = make_shared(pattern_concat, matcher_name); + this->register_matcher(m, callback); +} + pass::EliminateSplit::EliminateSplit() { MATCHER_SCOPE(EliminateSplit); auto convert_pattern = pattern::wrap_type(); @@ -1154,6 +1325,7 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { ADD_MATCHER_FOR_THIS(EliminateConvert) ADD_MATCHER_FOR_THIS(EliminateConvertNonZero) ADD_MATCHER_FOR_THIS(EliminateConcat) + ADD_MATCHER_FOR_THIS(EliminateConcatSplit) ADD_MATCHER_FOR_THIS(EliminateSplit) ADD_MATCHER_FOR_THIS(EliminateTranspose) ADD_MATCHER_FOR_THIS(EliminateEltwise) diff --git a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp index 7da5c79981d415..f1fe20d85eabaf 100644 --- a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp @@ -1537,6 +1537,110 @@ TEST_F(TransformationTestsF, EliminateSliceBeforeGatherElements) { } } +TEST_F(TransformationTestsF, EliminateConcatSplit) { + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto param4 = make_shared(element::f32, Shape{2, 10, 6}); + auto concat = make_shared(ov::as_output_vector({param1, param2, param3, param4}), axis); + + auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto strided_slice1 = std::make_shared(concat, + begin_const1, + end_const1, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + + auto add_param = make_shared(element::f32, Shape{2, 10, 3}); + auto add = std::make_shared(strided_slice1, add_param); + auto result1 = std::make_shared(add); + + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 18}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice2); + auto result2 = std::make_shared(relu); + + model = std::make_shared(ResultVector{result1, result2}, + ParameterVector{param1, param2, param3, param4, add_param}); + manager.register_pass(); + } + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto param4 = make_shared(element::f32, Shape{2, 10, 6}); + auto add_param = make_shared(element::f32, Shape{2, 10, 3}); + auto add = std::make_shared(param1, add_param); + auto result1 = std::make_shared(add); + + auto concat = make_shared(ov::as_output_vector({param2, param3, param4}), axis); + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 15}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice2); + auto result2 = std::make_shared(relu); + + model_ref = std::make_shared(ResultVector{result1, result2}, + ParameterVector{param1, param2, param3, param4, add_param}); + } +} + +TEST_F(TransformationTestsF, EliminateConcatSplitConcat) { + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + auto concat = make_shared(ov::as_output_vector({param1, param2, param3}), axis); + + auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0}); + auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto strided_slice1 = std::make_shared(concat, + begin_const1, + end_const1, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto relu = std::make_shared(strided_slice1); + + auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3}); + auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 12}); + auto strided_slice2 = std::make_shared(concat, + begin_const2, + end_const2, + std::vector{1, 1, 0}, + std::vector{1, 1, 0}); + auto concat1 = make_shared(ov::as_output_vector({relu, strided_slice2}), axis); + + auto result = std::make_shared(concat1); + model = std::make_shared(ResultVector{result}, ParameterVector{param1, param2, param3}); + manager.register_pass(); + } + { + int64_t axis = 2; + auto param1 = make_shared(element::f32, Shape{2, 10, 3}); + auto param2 = make_shared(element::f32, Shape{2, 10, 4}); + auto param3 = make_shared(element::f32, Shape{2, 10, 5}); + + auto relu = std::make_shared(param1); + auto concat = make_shared(ov::as_output_vector({relu, param2, param3}), axis); + auto result = std::make_shared(concat); + + model_ref = std::make_shared(ResultVector{result}, ParameterVector{param1, param2, param3}); + } +} TEST_F(TransformationTestsF, EliminateStridedSlice) { { auto input = std::make_shared(ov::element::f32, diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index 185d3380f42859..7db48179d49c8d 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -172,6 +172,16 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector& mask) std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs); + if (new_args.size() == 3) { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + m_begin_mask, + m_end_mask, + m_new_axis_mask, + m_shrink_axis_mask, + m_ellipsis_mask); + } check_new_args_count(this, new_args); return std::make_shared(new_args.at(0), new_args.at(1), From cf797d4e8574ec21275b5abfffddb029b7b625cb Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Mon, 10 Feb 2025 15:11:08 +0800 Subject: [PATCH 2/9] Fix error for dynamic shape Signed-off-by: Zhai, Xuejun --- .../common_optimizations/nop_elimination.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index a52c78d17c5cfa..6030c6ee7d05d1 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "compare.hpp" #include "itt.hpp" @@ -518,7 +519,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { matcher_pass_callback callback = [=](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_map(); const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); - if (!concat) { + if (!concat || concat->is_dynamic()) { return false; } const auto concat_users = concat->get_users(); @@ -528,7 +529,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { concat_axis = concat_axis + concat->get_output_shape(0).size(); bool need_opt = false; - std::vector, u_int64_t, u_int64_t>> slice_out_index_in_cocat; + std::vector, int64_t, int64_t>> slice_out_index_in_cocat; for (const auto& user : concat_users) { if (ov::is_type(user)) { auto strided_slice_node = ov::as_type_ptr(user); @@ -555,25 +556,24 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { auto begin_node = strided_slice_node->get_input_node_shared_ptr(1); const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); auto begin_values = begin_constant_node->cast_vector(); - auto begin_mask = strided_slice_node->get_begin_mask(); auto end_node = strided_slice_node->get_input_node_shared_ptr(2); const auto& end_constant_node = ov::util::get_constant_from_source(end_node); auto end_values = end_constant_node->cast_vector(); - auto end_mask = strided_slice_node->get_end_mask(); need_opt = true; slice_out_index_in_cocat.push_back( std::make_tuple(strided_slice_node, begin_values[concat_axis], end_values[concat_axis] - 1)); } else { need_opt = false; + break; } } - if (!need_opt) + if (!need_opt || (slice_out_index_in_cocat.size() == 1)) return false; uint64_t start_index = 0; - std::vector, u_int64_t, u_int64_t>> in_index_in_cocat; + std::vector, int64_t, int64_t>> in_index_in_cocat; for (auto& concat_in : concat_inputs) { auto tmp_index = start_index + concat_in.get_shape()[concat_axis] - 1; in_index_in_cocat.push_back( @@ -581,7 +581,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { start_index = tmp_index + 1; } - std::vector, u_int64_t, u_int64_t>> mismatch_slices{}; + std::vector, int64_t, int64_t>> mismatch_slices{}; for (const auto& slice_out : slice_out_index_in_cocat) { bool matched = false; for (const auto& concat_in : in_index_in_cocat) { @@ -598,8 +598,8 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { mismatch_slices.push_back(slice_out); } - u_int64_t new_start_value{std::numeric_limits::max()}; - u_int64_t new_end_value{0}; + int64_t new_start_value{std::numeric_limits::max()}; + int64_t new_end_value{0}; for (const auto& mismatch_slice : mismatch_slices) { for (const auto& concat_in : in_index_in_cocat) { if ((get<1>(concat_in) <= get<1>(mismatch_slice)) && (get<2>(concat_in) > get<1>(mismatch_slice))) { @@ -1325,13 +1325,13 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { ADD_MATCHER_FOR_THIS(EliminateConvert) ADD_MATCHER_FOR_THIS(EliminateConvertNonZero) ADD_MATCHER_FOR_THIS(EliminateConcat) - ADD_MATCHER_FOR_THIS(EliminateConcatSplit) ADD_MATCHER_FOR_THIS(EliminateSplit) ADD_MATCHER_FOR_THIS(EliminateTranspose) ADD_MATCHER_FOR_THIS(EliminateEltwise) ADD_MATCHER_FOR_THIS(EliminateSplitConcat) ADD_MATCHER_FOR_THIS(EliminateStridedSlice) ADD_MATCHER_FOR_THIS(EliminateSlice) + ADD_MATCHER_FOR_THIS(EliminateConcatSplit) // shape-dependent transformations if (use_shape_for_elimination) { From a70376de27c0b21c5c0f5f319e64f403f12e4e81 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Tue, 11 Feb 2025 11:37:28 +0800 Subject: [PATCH 3/9] Fix review comments Signed-off-by: Zhai, Xuejun --- .../common_optimizations/nop_elimination.cpp | 61 +++++++++---------- src/core/src/op/strided_slice.cpp | 10 ++- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 6030c6ee7d05d1..afc253256287d9 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -573,70 +573,69 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { return false; uint64_t start_index = 0; - std::vector, int64_t, int64_t>> in_index_in_cocat; + std::vector, int64_t, int64_t>> in_index_in_concat; for (auto& concat_in : concat_inputs) { auto tmp_index = start_index + concat_in.get_shape()[concat_axis] - 1; - in_index_in_cocat.push_back( + in_index_in_concat.push_back( std::make_tuple(concat_in.get_source_output().get_node_shared_ptr(), start_index, tmp_index)); start_index = tmp_index + 1; } std::vector, int64_t, int64_t>> mismatch_slices{}; - for (const auto& slice_out : slice_out_index_in_cocat) { + for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_cocat) { bool matched = false; - for (const auto& concat_in : in_index_in_cocat) { - if (get<1>(slice_out) == get<1>(concat_in) && get<2>(slice_out) == get<2>(concat_in)) { - auto slice_outputs = get<0>(slice_out)->outputs(); + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if (slice_begin == concat_input_begin && slice_end == concat_input_end) { + auto slice_outputs = slice_node->outputs(); for (auto& slice_output : slice_outputs) { - replace_output_update_name(slice_output, get<0>(concat_in)); + replace_output_update_name(slice_output, concat_input_node); } matched = true; break; } } if (!matched) - mismatch_slices.push_back(slice_out); + mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end)); } int64_t new_start_value{std::numeric_limits::max()}; int64_t new_end_value{0}; - for (const auto& mismatch_slice : mismatch_slices) { - for (const auto& concat_in : in_index_in_cocat) { - if ((get<1>(concat_in) <= get<1>(mismatch_slice)) && (get<2>(concat_in) > get<1>(mismatch_slice))) { - if (get<1>(concat_in) < new_start_value) - new_start_value = get<1>(concat_in); - if (get<2>(concat_in) > new_end_value) - new_end_value = get<2>(concat_in); + for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if ((concat_input_begin <= slice_begin) && (concat_input_end > slice_begin)) { + if (concat_input_begin < new_start_value) + new_start_value = concat_input_begin; + if (concat_input_end > new_end_value) + new_end_value = concat_input_end; } - if ((get<1>(concat_in) < get<2>(mismatch_slice)) && (get<2>(concat_in) >= get<2>(mismatch_slice))) { - if (get<1>(concat_in) < new_start_value) - new_start_value = get<1>(concat_in); - if (get<2>(concat_in) > new_end_value) - new_end_value = get<2>(concat_in); + if ((concat_input_begin < slice_end) && (concat_input_end >= slice_end)) { + if (concat_input_begin < new_start_value) + new_start_value = concat_input_begin; + if (concat_input_end > new_end_value) + new_end_value = concat_input_end; } } } std::vector> new_concat_in_nodes{}; bool new_need = false; - for (const auto& concat_in : in_index_in_cocat) { - if (get<1>(concat_in) == new_start_value) { + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + if (concat_input_begin == new_start_value) { new_need = true; } - if (get<2>(concat_in) == new_end_value) { - new_concat_in_nodes.push_back(get<0>(concat_in)); + if (concat_input_end == new_end_value) { + new_concat_in_nodes.push_back(concat_input_node); new_need = false; } if (new_need) { - new_concat_in_nodes.push_back(get<0>(concat_in)); + new_concat_in_nodes.push_back(concat_input_node); } } auto new_concat_node = concat->clone_with_new_inputs(ov::as_output_vector(new_concat_in_nodes)); replace_output_update_name(concat, new_concat_node); - for (const auto& mismatch_slice : mismatch_slices) { - auto& slice_node = get<0>(mismatch_slice); + for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { if (slice_node->get_users().size() == 1 && ov::is_type(slice_node->get_users()[0]) && ov::as_type_ptr(slice_node->get_users()[0])->get_axis() == concat_axis) { auto next_concat = ov::as_type_ptr(slice_node->get_users()[0]); @@ -661,19 +660,19 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { auto begin_node = slice_node->get_input_node_shared_ptr(1); const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node); auto begin_values = begin_constant_node->cast_vector(); - begin_values[concat_axis] = get<1>(mismatch_slice) - new_start_value; + begin_values[concat_axis] = slice_begin - new_start_value; new_slice_in_nodes.push_back( ov::op::v0::Constant::create(ov::element::i64, ov::Shape{begin_values.size()}, begin_values)); auto end_node = slice_node->get_input_node_shared_ptr(2); const auto& end_constant_node = ov::util::get_constant_from_source(end_node); auto end_values = end_constant_node->cast_vector(); - end_values[concat_axis] = get<2>(mismatch_slice) - new_start_value + 1; + end_values[concat_axis] = slice_end - new_start_value + 1; new_slice_in_nodes.push_back( ov::op::v0::Constant::create(ov::element::i64, ov::Shape{end_values.size()}, end_values)); auto new_slice_node = - get<0>(mismatch_slice)->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); - replace_output_update_name(get<0>(mismatch_slice), new_slice_node); + slice_node->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); + replace_output_update_name(slice_node, new_slice_node); } } return true; diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index 7db48179d49c8d..74fc2e75209062 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -172,6 +172,14 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector& mask) std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs); + // check_new_args_count(this, new_args); + NODE_VALIDATION_CHECK(this, + (new_args.size() == 3) || (new_args.size() == 4), + "clone_with_new_inputs() expected 3 or 4", + " arguments", + " but got ", + new_args.size()); + if (new_args.size() == 3) { return std::make_shared(new_args.at(0), new_args.at(1), @@ -182,7 +190,7 @@ std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& ne m_shrink_axis_mask, m_ellipsis_mask); } - check_new_args_count(this, new_args); + // check_new_args_count(this, new_args); return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), From 08552d8d67149c13e34b4e3ff23ef2f038300cf0 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Thu, 13 Feb 2025 10:09:58 +0800 Subject: [PATCH 4/9] Add some log for debug Signed-off-by: Zhai, Xuejun --- .../common_optimizations/nop_elimination.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index afc253256287d9..d09110586f41b5 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -598,6 +598,19 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end)); } + for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_cocat) { + std::cout << "[slice][node] " << slice_node->get_name() << ", begin: " << slice_begin + << ", end: " << slice_end << std::endl; + } + for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { + std::cout << "[concat][node] " << concat_input_node->get_name() << ", begin: " << concat_input_begin + << ", end: " << concat_input_end << std::endl; + } + for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { + std::cout << "[mismach][node] " << slice_node->get_name() << ", begin: " << slice_begin + << ", end: " << slice_end << std::endl; + } + int64_t new_start_value{std::numeric_limits::max()}; int64_t new_end_value{0}; for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { From b86195615ca4ac408ab7c2dd8a4d20af0dafc725 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Fri, 14 Feb 2025 19:33:45 +0800 Subject: [PATCH 5/9] Revert "Add some log for debug" This reverts commit 08552d8d67149c13e34b4e3ff23ef2f038300cf0. --- .../common_optimizations/nop_elimination.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index d09110586f41b5..afc253256287d9 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -598,19 +598,6 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end)); } - for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_cocat) { - std::cout << "[slice][node] " << slice_node->get_name() << ", begin: " << slice_begin - << ", end: " << slice_end << std::endl; - } - for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) { - std::cout << "[concat][node] " << concat_input_node->get_name() << ", begin: " << concat_input_begin - << ", end: " << concat_input_end << std::endl; - } - for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { - std::cout << "[mismach][node] " << slice_node->get_name() << ", begin: " << slice_begin - << ", end: " << slice_end << std::endl; - } - int64_t new_start_value{std::numeric_limits::max()}; int64_t new_end_value{0}; for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) { From eeb5347c4a40a5b77ba9c668ed4a90aa506b3a7a Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Fri, 14 Feb 2025 19:34:55 +0800 Subject: [PATCH 6/9] Fix test error Signed-off-by: Zhai, Xuejun --- .../transformations/common_optimizations/nop_elimination.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index afc253256287d9..d1ef185825da51 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -519,7 +519,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { matcher_pass_callback callback = [=](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_map(); const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); - if (!concat || concat->is_dynamic()) { + if (!concat || concat->is_dynamic() || concat->get_users().size()) { return false; } const auto concat_users = concat->get_users(); From a4fbe519f829223653a16657bbe162db6c5e0914 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Fri, 14 Feb 2025 22:24:24 +0800 Subject: [PATCH 7/9] Fix review comments Signed-off-by: Zhai, Xuejun --- .../common_optimizations/nop_elimination.cpp | 2 +- src/core/src/op/strided_slice.cpp | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index d1ef185825da51..57d5e61d8fb17f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -519,7 +519,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { matcher_pass_callback callback = [=](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_map(); const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); - if (!concat || concat->is_dynamic() || concat->get_users().size()) { + if (!concat || concat->is_dynamic() || concat->get_users().size()==1) { return false; } const auto concat_users = concat->get_users(); diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index 74fc2e75209062..e796b9a915ad4d 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -172,16 +172,11 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector& mask) std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs); - // check_new_args_count(this, new_args); - NODE_VALIDATION_CHECK(this, - (new_args.size() == 3) || (new_args.size() == 4), - "clone_with_new_inputs() expected 3 or 4", - " arguments", - " but got ", - new_args.size()); - - if (new_args.size() == 3) { - return std::make_shared(new_args.at(0), + auto args_size =new_args.size(); + NODE_VALIDATION_CHECK(this, (args_size == 3) || (args_size == 4), "Incorrect number of new inputs: ", args_size); + + if (args_size == 3) { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_begin_mask, @@ -190,7 +185,6 @@ std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& ne m_shrink_axis_mask, m_ellipsis_mask); } - // check_new_args_count(this, new_args); return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), From a0284a1de64d419a994c7343f176860387bb09f7 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Tue, 18 Feb 2025 09:17:29 +0800 Subject: [PATCH 8/9] Fix clang format issue Signed-off-by: Zhai, Xuejun --- src/core/src/op/strided_slice.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index e796b9a915ad4d..09d773c462dd03 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -172,18 +172,18 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector& mask) std::shared_ptr StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs); - auto args_size =new_args.size(); + auto args_size = new_args.size(); NODE_VALIDATION_CHECK(this, (args_size == 3) || (args_size == 4), "Incorrect number of new inputs: ", args_size); if (args_size == 3) { return std::make_shared(new_args.at(0), - new_args.at(1), - new_args.at(2), - m_begin_mask, - m_end_mask, - m_new_axis_mask, - m_shrink_axis_mask, - m_ellipsis_mask); + new_args.at(1), + new_args.at(2), + m_begin_mask, + m_end_mask, + m_new_axis_mask, + m_shrink_axis_mask, + m_ellipsis_mask); } return std::make_shared(new_args.at(0), new_args.at(1), From c27f95f1085ab3034cc90b30606578265416923b Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Tue, 18 Feb 2025 13:02:28 +0800 Subject: [PATCH 9/9] Fix clang format issue Signed-off-by: Zhai, Xuejun --- .../transformations/common_optimizations/nop_elimination.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 57d5e61d8fb17f..254246c4f8240f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -519,7 +519,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { matcher_pass_callback callback = [=](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_map(); const auto concat = ov::as_type_ptr(pattern_map.at(pattern_concat)); - if (!concat || concat->is_dynamic() || concat->get_users().size()==1) { + if (!concat || concat->is_dynamic() || concat->get_users().size() == 1) { return false; } const auto concat_users = concat->get_users(); @@ -670,8 +670,7 @@ pass::EliminateConcatSplit::EliminateConcatSplit() { end_values[concat_axis] = slice_end - new_start_value + 1; new_slice_in_nodes.push_back( ov::op::v0::Constant::create(ov::element::i64, ov::Shape{end_values.size()}, end_values)); - auto new_slice_node = - slice_node->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); + auto new_slice_node = slice_node->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes)); replace_output_update_name(slice_node, new_slice_node); } }