Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Common][Transformations] eliminated no needed split & concat #28827

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API EliminateConcat;
class TRANSFORMATIONS_API EliminateConcatSplit;
class TRANSFORMATIONS_API EliminateConvert;
class TRANSFORMATIONS_API EliminateConvertNonZero;
class TRANSFORMATIONS_API EliminateEltwise;
Expand Down Expand Up @@ -83,6 +84,16 @@ class ov::pass::EliminateConcat : public ov::pass::MatcherPass {
EliminateConcat();
};

/**
* @ingroup ov_transformation_common_api
* @brief EliminateConcatSplit eliminates split from concat that no need
*/
class ov::pass::EliminateConcatSplit : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("EliminateConcatSplit");
EliminateConcatSplit();
};

/**
* @ingroup ov_transformation_common_api
* @brief EliminateSplit eliminates split that does nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <functional>
#include <memory>
#include <numeric>
#include <tuple>

#include "compare.hpp"
#include "itt.hpp"
Expand Down Expand Up @@ -512,6 +513,174 @@ pass::EliminateConcat::EliminateConcat() {
this->register_matcher(m, callback);
}

pass::EliminateConcatSplit::EliminateConcatSplit() {
MATCHER_SCOPE(EliminateConcatSplit);
auto pattern_concat = pattern::wrap_type<ov::op::v0::Concat>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_map();
const auto concat = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(pattern_concat));
if (!concat || concat->is_dynamic() || concat->get_users().size() == 1) {
return false;
}
const auto concat_users = concat->get_users();
auto concat_inputs = concat->inputs();
auto concat_axis = concat->get_axis();
if (concat_axis < 0)
concat_axis = concat_axis + concat->get_output_shape(0).size();

bool need_opt = false;
std::vector<std::tuple<std::shared_ptr<Node>, int64_t, int64_t>> slice_out_index_in_cocat;
for (const auto& user : concat_users) {
if (ov::is_type<ov::op::v1::StridedSlice>(user)) {
auto strided_slice_node = ov::as_type_ptr<ov::op::v1::StridedSlice>(user);
if (!strided_slice_node) {
return false;
}
// check that all values of the mask is equal 0
auto check_mask = [](const std::vector<int64_t>& mask_to_check) {
auto it = std::find_if(mask_to_check.begin(), mask_to_check.end(), [](const int64_t& value) {
return value != 0;
});
if (mask_to_check.empty() || it == mask_to_check.end()) {
return true;
}
return false;
};
// check that we won't do change dimention rank
if (!check_mask(strided_slice_node->get_shrink_axis_mask()) ||
!check_mask(strided_slice_node->get_new_axis_mask()) ||
!check_mask(strided_slice_node->get_ellipsis_mask())) {
return false;
}

auto begin_node = strided_slice_node->get_input_node_shared_ptr(1);
const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node);
auto begin_values = begin_constant_node->cast_vector<int64_t>();

auto end_node = strided_slice_node->get_input_node_shared_ptr(2);
const auto& end_constant_node = ov::util::get_constant_from_source(end_node);
auto end_values = end_constant_node->cast_vector<int64_t>();

need_opt = true;
slice_out_index_in_cocat.push_back(
std::make_tuple(strided_slice_node, begin_values[concat_axis], end_values[concat_axis] - 1));
} else {
need_opt = false;
break;
}
}
if (!need_opt || (slice_out_index_in_cocat.size() == 1))
return false;

uint64_t start_index = 0;
std::vector<std::tuple<std::shared_ptr<Node>, int64_t, int64_t>> in_index_in_concat;
for (auto& concat_in : concat_inputs) {
auto tmp_index = start_index + concat_in.get_shape()[concat_axis] - 1;
in_index_in_concat.push_back(
std::make_tuple(concat_in.get_source_output().get_node_shared_ptr(), start_index, tmp_index));
start_index = tmp_index + 1;
}

std::vector<std::tuple<std::shared_ptr<Node>, int64_t, int64_t>> mismatch_slices{};
for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_cocat) {
bool matched = false;
for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) {
if (slice_begin == concat_input_begin && slice_end == concat_input_end) {
auto slice_outputs = slice_node->outputs();
for (auto& slice_output : slice_outputs) {
replace_output_update_name(slice_output, concat_input_node);
}
matched = true;
break;
}
}
if (!matched)
mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end));
}

int64_t new_start_value{std::numeric_limits<int64_t>::max()};
int64_t new_end_value{0};
for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) {
for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) {
if ((concat_input_begin <= slice_begin) && (concat_input_end > slice_begin)) {
if (concat_input_begin < new_start_value)
new_start_value = concat_input_begin;
if (concat_input_end > new_end_value)
new_end_value = concat_input_end;
}
if ((concat_input_begin < slice_end) && (concat_input_end >= slice_end)) {
if (concat_input_begin < new_start_value)
new_start_value = concat_input_begin;
if (concat_input_end > new_end_value)
new_end_value = concat_input_end;
}
}
}

std::vector<std::shared_ptr<Node>> new_concat_in_nodes{};
bool new_need = false;
for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) {
if (concat_input_begin == new_start_value) {
new_need = true;
}
if (concat_input_end == new_end_value) {
new_concat_in_nodes.push_back(concat_input_node);
new_need = false;
}
if (new_need) {
new_concat_in_nodes.push_back(concat_input_node);
}
}

auto new_concat_node = concat->clone_with_new_inputs(ov::as_output_vector(new_concat_in_nodes));
replace_output_update_name(concat, new_concat_node);

for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) {
if (slice_node->get_users().size() == 1 && ov::is_type<ov::op::v0::Concat>(slice_node->get_users()[0]) &&
ov::as_type_ptr<ov::op::v0::Concat>(slice_node->get_users()[0])->get_axis() == concat_axis) {
auto next_concat = ov::as_type_ptr<ov::op::v0::Concat>(slice_node->get_users()[0]);
auto next_concat_inputs = next_concat->input_values();
std::vector<std::shared_ptr<Node>> new_next_concat_inputs{};
for (const auto& t : next_concat_inputs) {
if (t.get_node_shared_ptr() == slice_node) {
for (const auto& need_insert : new_concat_in_nodes) {
new_next_concat_inputs.push_back(need_insert);
}
continue;
}
new_next_concat_inputs.push_back(t.get_node_shared_ptr());
}
auto new_next_concat_node =
next_concat->clone_with_new_inputs(ov::as_output_vector(new_next_concat_inputs));
replace_output_update_name(next_concat, new_next_concat_node);
} else {
std::vector<std::shared_ptr<Node>> new_slice_in_nodes{};
new_slice_in_nodes.push_back(new_concat_node);

auto begin_node = slice_node->get_input_node_shared_ptr(1);
const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node);
auto begin_values = begin_constant_node->cast_vector<int64_t>();
begin_values[concat_axis] = slice_begin - new_start_value;
new_slice_in_nodes.push_back(
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{begin_values.size()}, begin_values));

auto end_node = slice_node->get_input_node_shared_ptr(2);
const auto& end_constant_node = ov::util::get_constant_from_source(end_node);
auto end_values = end_constant_node->cast_vector<int64_t>();
end_values[concat_axis] = slice_end - new_start_value + 1;
new_slice_in_nodes.push_back(
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{end_values.size()}, end_values));
auto new_slice_node = slice_node->clone_with_new_inputs(ov::as_output_vector(new_slice_in_nodes));
replace_output_update_name(slice_node, new_slice_node);
}
}
return true;
};

auto m = make_shared<pattern::Matcher>(pattern_concat, matcher_name);
this->register_matcher(m, callback);
}

pass::EliminateSplit::EliminateSplit() {
MATCHER_SCOPE(EliminateSplit);
auto convert_pattern = pattern::wrap_type<ov::op::v1::Split>();
Expand Down Expand Up @@ -1160,6 +1329,7 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
ADD_MATCHER_FOR_THIS(EliminateSplitConcat)
ADD_MATCHER_FOR_THIS(EliminateStridedSlice)
ADD_MATCHER_FOR_THIS(EliminateSlice)
ADD_MATCHER_FOR_THIS(EliminateConcatSplit)

// shape-dependent transformations
if (use_shape_for_elimination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,110 @@ TEST_F(TransformationTestsF, EliminateSliceBeforeGatherElements) {
}
}

TEST_F(TransformationTestsF, EliminateConcatSplit) {
{
int64_t axis = 2;
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
auto param4 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 6});
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3, param4}), axis);

auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
begin_const1,
end_const1,
std::vector<int64_t>{1, 1, 0},
std::vector<int64_t>{1, 1, 0});

auto add_param = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto add = std::make_shared<op::v1::Add>(strided_slice1, add_param);
auto result1 = std::make_shared<op::v0::Result>(add);

auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 18});
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
begin_const2,
end_const2,
std::vector<int64_t>{1, 1, 0},
std::vector<int64_t>{1, 1, 0});
auto relu = std::make_shared<op::v0::Relu>(strided_slice2);
auto result2 = std::make_shared<op::v0::Result>(relu);

model = std::make_shared<ov::Model>(ResultVector{result1, result2},
ParameterVector{param1, param2, param3, param4, add_param});
manager.register_pass<ov::pass::EliminateConcatSplit>();
}
{
int64_t axis = 2;
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
auto param4 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 6});
auto add_param = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto add = std::make_shared<op::v1::Add>(param1, add_param);
auto result1 = std::make_shared<op::v0::Result>(add);

auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param2, param3, param4}), axis);
auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 15});
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
begin_const2,
end_const2,
std::vector<int64_t>{1, 1, 0},
std::vector<int64_t>{1, 1, 0});
auto relu = std::make_shared<op::v0::Relu>(strided_slice2);
auto result2 = std::make_shared<op::v0::Result>(relu);

model_ref = std::make_shared<ov::Model>(ResultVector{result1, result2},
ParameterVector{param1, param2, param3, param4, add_param});
}
}

TEST_F(TransformationTestsF, EliminateConcatSplitConcat) {
{
int64_t axis = 2;
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3}), axis);

auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
begin_const1,
end_const1,
std::vector<int64_t>{1, 1, 0},
std::vector<int64_t>{1, 1, 0});
auto relu = std::make_shared<op::v0::Relu>(strided_slice1);

auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 12});
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
begin_const2,
end_const2,
std::vector<int64_t>{1, 1, 0},
std::vector<int64_t>{1, 1, 0});
auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector({relu, strided_slice2}), axis);

auto result = std::make_shared<op::v0::Result>(concat1);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
manager.register_pass<ov::pass::EliminateConcatSplit>();
}
{
int64_t axis = 2;
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});

auto relu = std::make_shared<op::v0::Relu>(param1);
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({relu, param2, param3}), axis);
auto result = std::make_shared<op::v0::Result>(concat);

model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
}
}
TEST_F(TransformationTestsF, EliminateStridedSlice) {
{
auto input = std::make_shared<op::v0::Parameter>(ov::element::f32,
Expand Down
14 changes: 13 additions & 1 deletion src/core/src/op/strided_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,19 @@ AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector<int64_t>& mask)

std::shared_ptr<Node> StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs);
check_new_args_count(this, new_args);
auto args_size = new_args.size();
NODE_VALIDATION_CHECK(this, (args_size == 3) || (args_size == 4), "Incorrect number of new inputs: ", args_size);

if (args_size == 3) {
return std::make_shared<StridedSlice>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_begin_mask,
m_end_mask,
m_new_axis_mask,
m_shrink_axis_mask,
m_ellipsis_mask);
}
return std::make_shared<v1::StridedSlice>(new_args.at(0),
new_args.at(1),
new_args.at(2),
Expand Down
Loading