Skip to content

Commit

Permalink
Add SegmentMax-16 to Core (#28698)
Browse files Browse the repository at this point in the history
### Details:
 - Add `SegmentMax-16` to Core
 - Add shape inference function
 - Add tests
- A step towards enabling
https://www.tensorflow.org/api_docs/python/tf/math/segment_max and
https://www.tensorflow.org/api_docs/python/tf/raw_ops/SegmentMaxV2

### Related PRs:
 - #28103
 - #28788

### Tickets:
 - CVS-158916

---------

Signed-off-by: p-wysocki <[email protected]>
Co-authored-by: Pawel Raasz <[email protected]>
Co-authored-by: Katarzyna Mitrus <[email protected]>
  • Loading branch information
3 people authored Feb 13, 2025
1 parent 6aa2544 commit d455dd2
Show file tree
Hide file tree
Showing 10 changed files with 766 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/search_sorted.hpp"
#include "openvino/op/segment_max.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/selu.hpp"
#include "openvino/op/shape_of.hpp"
Expand Down
46 changes: 46 additions & 0 deletions src/core/include/openvino/op/segment_max.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov::op::v16 {
/// \brief An operation which computes the maximum values along segments of a tensor.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API SegmentMax : public ov::op::Op {
public:
OPENVINO_OP("SegmentMax", "opset16", ov::op::Op);

SegmentMax() = default;

/// \brief Constructs a SegmentMax operation.
///
/// \param data Input tensor with data
/// \param segment_ids Indices of segments in the data input tensor
/// \param fill_mode The value assigned to segments which are empty
SegmentMax(const Output<Node>& data, const Output<Node>& segment_ids, const op::FillMode fill_mode);

/// \brief Constructs a SegmentMax operation.
///
/// \param data Input tensor with data
/// \param segment_ids Indices of segments in the data input tensor
/// \param num_segments The segments count
/// \param fill_mode The value assigned to segments which are empty
SegmentMax(const Output<Node>& data,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const op::FillMode fill_mode);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

const op::FillMode get_fill_mode() const;

private:
op::FillMode m_fill_mode{};
};

} // namespace ov::op::v16
14 changes: 14 additions & 0 deletions src/core/include/openvino/op/util/attr_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ enum class PadMode { CONSTANT = 0, EDGE, REFLECT, SYMMETRIC };
OPENVINO_API
std::ostream& operator<<(std::ostream& s, const PadMode& type);

/// \brief Fill modes for the `SegmentMax` operator.
enum class FillMode { ZERO = 0, LOWEST };

OPENVINO_API
std::ostream& operator<<(std::ostream& s, const FillMode& type);

/// \brief Padding Type used for `Convolution` and `Pooling`
///
/// Follows ONNX padding type definitions
Expand Down Expand Up @@ -219,6 +225,14 @@ class OPENVINO_API AttributeAdapter<op::PadMode> : public EnumAttributeAdapterBa
OPENVINO_RTTI("AttributeAdapter<PadMode>");
};

template <>
class OPENVINO_API AttributeAdapter<op::FillMode> : public EnumAttributeAdapterBase<op::FillMode> {
public:
AttributeAdapter(op::FillMode& value) : EnumAttributeAdapterBase<op::FillMode>(value) {}

OPENVINO_RTTI("AttributeAdapter<FillMode>");
};

template <>
class OPENVINO_API AttributeAdapter<op::PadType> : public EnumAttributeAdapterBase<op::PadType> {
public:
Expand Down
82 changes: 82 additions & 0 deletions src/core/shape_inference/include/segment_max_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cmath>

#include "openvino/op/segment_max.hpp"
#include "utils.hpp"

namespace ov {
namespace op {
namespace v16 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const SegmentMax* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 || input_shapes.size() == 3);

// validate shape of data input
const auto& data_shape = input_shapes[0];
const auto is_data_shape_rank_static = data_shape.rank().is_static();
if (is_data_shape_rank_static) {
NODE_SHAPE_INFER_CHECK(op, input_shapes, data_shape.size() > 0, "The data input cannot be a scalar.");
}

// validate segment_ids input
const auto& segment_ids_shape = input_shapes[1];
const auto is_segment_ids_rank_static = segment_ids_shape.rank().is_static();
if (is_segment_ids_rank_static) {
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
segment_ids_shape.size() == 1,
"segment_ids must be a 1D input. Got: ",
segment_ids_shape);
if (is_data_shape_rank_static) {
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
data_shape[0].compatible(segment_ids_shape[0]),
"The number of elements in segment_ids must match the first dimension of data.");
}
}
const auto segment_ids = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor);
if (segment_ids) {
NODE_VALIDATION_CHECK(op,
std::is_sorted(segment_ids->begin(), segment_ids->end()),
"segment_ids must be sorted.");
}

// validate num_segments input
const auto num_segments_available = op->inputs().size() == 3;
const auto num_segments = num_segments_available ? get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor)
: ov::optional<TRShape>{};
if (num_segments_available) {
const auto& num_segments_shape = input_shapes[2];
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
num_segments_shape.rank().compatible(0),
"num_segments must be a scalar input. Got: ",
num_segments_shape);
}

if (!is_data_shape_rank_static) {
return {PartialShape::dynamic()};
}
using TDim = typename TShape::value_type;
auto output_shapes = std::vector<TRShape>{data_shape};
auto& output_shape = output_shapes[0];
if (num_segments) {
output_shape[0] = TDim((*num_segments)[0]);
} else if (segment_ids && !num_segments_available) {
output_shape[0] = TDim(segment_ids->back() + 1);
} else {
// if num_segments input is provided but the value is unknown, the first dimension should be dynamic
output_shape[0] = Dimension::dynamic();
}
return output_shapes;
}
} // namespace v16
} // namespace op
} // namespace ov
72 changes: 72 additions & 0 deletions src/core/src/op/segment_max.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/segment_max.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/op.hpp"
#include "segment_max_shape_inference.hpp"

namespace ov {
namespace op {
namespace v16 {

SegmentMax::SegmentMax(const Output<Node>& data, const Output<Node>& segment_ids, const op::FillMode fill_mode)
: Op({data, segment_ids}),
m_fill_mode(fill_mode) {
constructor_validate_and_infer_types();
}

SegmentMax::SegmentMax(const Output<Node>& data,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const op::FillMode fill_mode)
: Op({data, segment_ids, num_segments}),
m_fill_mode(fill_mode) {
constructor_validate_and_infer_types();
}

bool SegmentMax::visit_attributes(ov::AttributeVisitor& visitor) {
OV_OP_SCOPE(v16_SegmentMax_visit_attributes);
visitor.on_attribute("fill_mode", m_fill_mode);
return true;
}

void SegmentMax::validate_and_infer_types() {
OV_OP_SCOPE(v16_SegmentMax_validate_and_infer_types);
const auto& segment_ids_element_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
segment_ids_element_type == element::i32 || segment_ids_element_type == element::i64,
"The element type of the segment_ids input be i32 or i64. Got: ",
segment_ids_element_type);
if (inputs().size() == 3) {
const auto& num_segments_element_type = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
num_segments_element_type == element::i32 || num_segments_element_type == element::i64,
"The element type of the num_segments input be i32 or i64. Got: ",
num_segments_element_type);
}

const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}

std::shared_ptr<Node> SegmentMax::clone_with_new_inputs(const ov::OutputVector& new_args) const {
OV_OP_SCOPE(v16_SegmentMax_clone_with_new_inputs);
check_new_args_count(this, new_args);
if (new_args.size() == 3) {
return std::make_shared<SegmentMax>(new_args.at(0), new_args.at(1), new_args.at(2), m_fill_mode);
} else {
return std::make_shared<SegmentMax>(new_args.at(0), new_args.at(1), m_fill_mode);
}
}

const op::FillMode SegmentMax::get_fill_mode() const {
return m_fill_mode;
}

} // namespace v16
} // namespace op
} // namespace ov
12 changes: 12 additions & 0 deletions src/core/src/op/util/attr_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ OPENVINO_API EnumNames<ov::op::PadMode>& EnumNames<ov::op::PadMode>::get() {
return enum_names;
}

template <>
OPENVINO_API EnumNames<ov::op::FillMode>& EnumNames<ov::op::FillMode>::get() {
static auto enum_names =
EnumNames<ov::op::FillMode>("ov::op::FillMode",
{{"zero", ov::op::FillMode::ZERO}, {"lowest", ov::op::FillMode::LOWEST}});
return enum_names;
}

template <>
OPENVINO_API EnumNames<ov::op::PadType>& EnumNames<ov::op::PadType>::get() {
static auto enum_names = EnumNames<ov::op::PadType>("ov::op::PadType",
Expand Down Expand Up @@ -132,6 +140,10 @@ std::ostream& op::operator<<(std::ostream& s, const ov::op::PadMode& type) {
return s << as_string(type);
}

std::ostream& op::operator<<(std::ostream& s, const ov::op::FillMode& type) {
return s << as_string(type);
}

std::ostream& op::operator<<(std::ostream& s, const ov::op::PadType& type) {
return s << as_string(type);
}
Expand Down
Loading

0 comments on commit d455dd2

Please sign in to comment.