Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SegmentMax-16 reference implementation #28788

Merged
merged 47 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
d338b7e
WIP
p-wysocki Jan 27, 2025
462c43f
WIP
p-wysocki Jan 27, 2025
e3e44e2
Cleanup
p-wysocki Jan 27, 2025
adcce7d
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Jan 27, 2025
1affe5d
Fix two tests
p-wysocki Jan 27, 2025
1246f0a
Merge Core PR
p-wysocki Jan 28, 2025
b3a13d3
WIP
p-wysocki Jan 28, 2025
094131d
CR WIP
p-wysocki Jan 29, 2025
f8c9919
WIP CR
p-wysocki Jan 29, 2025
f841c4c
WIP CR
p-wysocki Jan 29, 2025
3e882ed
WIP
p-wysocki Jan 29, 2025
7cef3b5
Clenaup
p-wysocki Jan 29, 2025
36e4bf6
Merge Core PR
p-wysocki Jan 30, 2025
85bb1e6
WIP
p-wysocki Jan 30, 2025
53c02aa
WIP
p-wysocki Jan 30, 2025
5ff3e4f
Add an edge case to tests
p-wysocki Jan 30, 2025
3415e83
Cleanup
p-wysocki Feb 3, 2025
fb76824
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Feb 3, 2025
b8645d3
Apply CR
p-wysocki Feb 3, 2025
da74ce4
Merge upstream
p-wysocki Feb 3, 2025
4aed4e9
Merge branch 'segmentmax_core' into segmentmax_ref
p-wysocki Feb 3, 2025
2374c1b
Cleanup
p-wysocki Feb 3, 2025
8c86ddf
Fix git
p-wysocki Feb 3, 2025
b1c4ca4
Merge branch 'master' into segmentmax_ref
rkazants Feb 3, 2025
ca0e6f0
Merge branch 'master' into segmentmax_core
p-wysocki Feb 4, 2025
18396a6
Update src/core/shape_inference/include/segment_max_shape_inference.hpp
p-wysocki Feb 4, 2025
fea6b91
Update src/core/include/openvino/op/segment_max.hpp
p-wysocki Feb 4, 2025
4fcf899
Update src/core/tests/visitors/op/segment_max.cpp
p-wysocki Feb 4, 2025
303ad69
Update src/plugins/intel_cpu/tests/unit/shape_inference_test/segment_…
p-wysocki Feb 4, 2025
e33ca27
Update src/core/tests/type_prop/segment_max.cpp
p-wysocki Feb 4, 2025
8d091c5
Update src/core/tests/type_prop/segment_max.cpp
p-wysocki Feb 4, 2025
566e85e
Update src/core/include/openvino/op/segment_max.hpp
p-wysocki Feb 4, 2025
61979d0
Apply suggestions from code review
p-wysocki Feb 5, 2025
ae37028
Apply CR
p-wysocki Feb 5, 2025
303f383
Merge branch 'segmentmax_core' of https://github.com/p-wysocki/openvi…
p-wysocki Feb 5, 2025
58653ee
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Feb 5, 2025
7e46e7e
Merge branch 'segmentmax_core' into segmentmax_ref
p-wysocki Feb 5, 2025
e2f4838
CR
p-wysocki Feb 12, 2025
4d0ebe3
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Feb 12, 2025
cd76e31
Merge branch 'segmentmax_ref' of https://github.com/p-wysocki/openvin…
p-wysocki Feb 12, 2025
2123716
Update src/core/include/openvino/op/util/attr_types.hpp
p-wysocki Feb 12, 2025
4628d88
Apply CR
p-wysocki Feb 14, 2025
3257b11
Merge branch 'master' into segmentmax_ref
p-wysocki Feb 14, 2025
b832a52
Merge branch 'master' into segmentmax_ref
p-wysocki Feb 17, 2025
ca8fa02
Fix coverity issue
p-wysocki Feb 17, 2025
196f9a6
Fix clang
p-wysocki Feb 17, 2025
124de51
Attempt to fix win build
p-wysocki Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/core/include/openvino/op/util/attr_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ enum class PadMode { CONSTANT = 0, EDGE, REFLECT, SYMMETRIC };
OPENVINO_API
std::ostream& operator<<(std::ostream& s, const PadMode& type);

/// \brief Fill modes for the `SegmentMax` operator.
/// \brief Fill modes to set default value for operators like `SegmentMax`.
enum class FillMode { ZERO = 0, LOWEST };

OPENVINO_API
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset16_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)
// New operations added in opset16
_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)
55 changes: 55 additions & 0 deletions src/core/reference/include/openvino/reference/segment_max.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>
#include <limits>
#include <vector>

#include "openvino/core/shape.hpp"

namespace ov::reference {

template <typename T, typename T_idx, std::enable_if_t<std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
void segment_max(const T* data,
const Shape& data_shape,
const T_idx* segment_ids,
T* out,
const Shape& output_shape,
const T empty_segment_value) {
const T_idx num_segments = output_shape[0];
const auto inner_dim_size = shape_size(data_shape.begin() + 1, data_shape.end());

// Initialize output with empty_segment_value
std::fill(out, out + num_segments * inner_dim_size, empty_segment_value);

// Iterate over each element in the first dimension
for (size_t i = 0; i < data_shape[0]; ++i) {
const T_idx segment_id = segment_ids[i];
if (segment_id >= num_segments) {
continue;
}
// Iterate over each element in the inner dimensions
for (size_t j = 0; j < inner_dim_size; ++j) {
const size_t index = i * inner_dim_size + j;
const size_t out_index = segment_id * inner_dim_size + j;
// Update the maximum value for the current segment and inner dimension
out[out_index] = std::max(out[out_index], data[index]);
}
}
}

template <typename T, typename T_idx, std::enable_if_t<!std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
void segment_max(const T* data,
const Shape& data_shape,
const T_idx* segment_ids,
T* out,
const Shape& output_shape,
const T empty_segment_value) {
std::vector<int64_t> segment_ids_int64(segment_ids, segment_ids + data_shape[0]);
segment_max(data, data_shape, segment_ids_int64.data(), out, output_shape, empty_segment_value);
}

} // namespace ov::reference
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ std::vector<TRShape> shape_infer(const SegmentMax* op,

// validate num_segments input
const auto num_segments_available = op->inputs().size() == 3;
const auto num_segments = num_segments_available ? get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor)
: ov::optional<TRShape>{};
const auto& num_segments = num_segments_available
? ov::op::get_input_const_data_as<TRShape, int64_t>(op, 2, tensor_accessor)
: ov::optional<std::vector<int64_t>>{};
if (num_segments_available) {
const auto& num_segments_shape = input_shapes[2];
NODE_SHAPE_INFER_CHECK(op,
Expand Down
2 changes: 1 addition & 1 deletion src/core/tests/opset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset13, 186},
OpsetTestParams{ov::get_opset14, 188},
OpsetTestParams{ov::get_opset15, 199},
OpsetTestParams{ov::get_opset16, 5}),
OpsetTestParams{ov::get_opset16, 6}),
OpsetTestNameGenerator{});

class MyOpOld : public ov::op::Op {
Expand Down
20 changes: 10 additions & 10 deletions src/core/tests/type_prop/segment_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace ov::test {
using op::v0::Constant, op::v0::Parameter, op::v1::Add, op::v1::ReduceMax, op::v1::StridedSlice, op::v3::ShapeOf;
using testing::HasSubstr;

class TypePropSegmentMaxTest : public TypePropOpTest<op::v16::SegmentMax> {};

Expand Down Expand Up @@ -69,45 +70,44 @@ TEST_F(TypePropSegmentMaxTest, incorrect_inputs) {
const auto num_segments_f32 = std::make_shared<Parameter>(element::f32, PartialShape{});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_f32, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The element type of the num_segments input be i32 or i64."));
HasSubstr("The element type of the num_segments input be i32 or i64."));
}
{
const auto segment_ids_f32 = std::make_shared<Parameter>(element::f32, PartialShape{3});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_f32, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The element type of the segment_ids input be i32 or i64."));
HasSubstr("The element type of the segment_ids input be i32 or i64."));
}
{
const auto segment_ids_nd = std::make_shared<Parameter>(element::i32, PartialShape{2, 3});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_nd, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("segment_ids must be a 1D input."));
HasSubstr("segment_ids must be a 1D input."));
}
{
const auto num_segments_nd = std::make_shared<Parameter>(element::i32, PartialShape{1});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_nd, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("num_segments must be a scalar input."));
HasSubstr("num_segments must be a scalar input."));
}
{
const auto segment_ids_unsorted =
std::make_shared<Constant>(element::i32, Shape{3}, std::vector<int64_t>{1, 0, 1});
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_unsorted, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("segment_ids must be sorted."));
HasSubstr("segment_ids must be sorted."));
}
{
const auto data_scalar = std::make_shared<Parameter>(element::i32, PartialShape{});
OV_EXPECT_THROW(std::ignore = make_op(data_scalar, segment_ids, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The data input cannot be a scalar."));
HasSubstr("The data input cannot be a scalar."));
}
{
const auto segment_ids_short = std::make_shared<Constant>(element::i32, Shape{2}, std::vector<int64_t>{1, 0});
OV_EXPECT_THROW(
std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
testing::HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
ov::NodeValidationFailure,
HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,7 @@ extern template bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov
extern template bool evaluate_node<ov::op::v16::Identity>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
80 changes: 80 additions & 0 deletions src/plugins/template/backend/ops/segment_max.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/segment_max.hpp"

#include "element_visitor.hpp"
#include "evaluate_node.hpp"
#include "segment_max_shape_inference.hpp"

template <ov::element::Type_t ET_data, ov::element::Type_t ET_idx>
bool evaluate_index_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
using T_data = typename ov::element_type_traits<ET_data>::value_type;
using T_idx = typename ov::element_type_traits<ET_idx>::value_type;
auto input_shapes = std::vector<ov::PartialShape>{op->get_input_shape(0), op->get_input_shape(1)};
if (op->inputs().size() == 3) {
input_shapes.emplace_back(op->get_input_shape(2));
}
const auto output_shape =
ov::op::v16::shape_infer(op.get(), input_shapes, make_tensor_accessor(inputs)).front().to_shape();
outputs.front().set_shape(output_shape);
const auto empty_segment_value =
op->get_fill_mode() == ov::op::FillMode::ZERO ? T_data(0) : std::numeric_limits<T_data>::lowest();
ov::reference::segment_max(inputs[0].data<const T_data>(),
inputs[0].get_shape(),
inputs[1].data<const T_idx>(),
outputs[0].data<T_data>(),
outputs[0].get_shape(),
empty_segment_value);
return true;
}

template <ov::element::Type_t ET_data>
bool evaluate_data_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& index_type = op->get_input_element_type(1);
using ov::op::v16::SegmentMax;
using namespace ov::element;
switch (index_type) {
case i32:
return evaluate_index_type<ET_data, i32>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
case i64:
return evaluate_index_type<ET_data, i64>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
default:
OPENVINO_THROW("Unhandled index type ", index_type, " in evaluate_node()");
}
}

template <>
bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& element_type = node->get_output_element_type(0);

using ov::op::v16::SegmentMax;
using namespace ov::element;
switch (element_type) {
case i8:
return evaluate_data_type<i8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case i32:
return evaluate_data_type<i32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case i64:
return evaluate_data_type<i64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u8:
return evaluate_data_type<u8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u32:
return evaluate_data_type<u32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case u64:
return evaluate_data_type<u64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case f16:
return evaluate_data_type<f16>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
case f32:
return evaluate_data_type<f32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
default:
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
}
}
1 change: 1 addition & 0 deletions src/plugins/template/backend/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ _OPENVINO_OP_REG(SearchSorted, ov::op::v15)

_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
Expand Down
Loading
Loading