-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### Details: - Add `SegmentMax-16` to Core - Add shape inference function - Add tests - A step towards enabling https://www.tensorflow.org/api_docs/python/tf/math/segment_max and https://www.tensorflow.org/api_docs/python/tf/raw_ops/SegmentMaxV2 ### Related PRs: - #28103 - #28788 ### Tickets: - CVS-158916 --------- Signed-off-by: p-wysocki <[email protected]> Co-authored-by: Pawel Raasz <[email protected]> Co-authored-by: Katarzyna Mitrus <[email protected]>
- Loading branch information
1 parent
6aa2544
commit d455dd2
Showing
10 changed files
with
766 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov::op::v16 { | ||
/// \brief An operation which computes the maximum values along segments of a tensor. | ||
/// \ingroup ov_ops_cpp_api | ||
class OPENVINO_API SegmentMax : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("SegmentMax", "opset16", ov::op::Op); | ||
|
||
SegmentMax() = default; | ||
|
||
/// \brief Constructs a SegmentMax operation. | ||
/// | ||
/// \param data Input tensor with data | ||
/// \param segment_ids Indices of segments in the data input tensor | ||
/// \param fill_mode The value assigned to segments which are empty | ||
SegmentMax(const Output<Node>& data, const Output<Node>& segment_ids, const op::FillMode fill_mode); | ||
|
||
/// \brief Constructs a SegmentMax operation. | ||
/// | ||
/// \param data Input tensor with data | ||
/// \param segment_ids Indices of segments in the data input tensor | ||
/// \param num_segments The segments count | ||
/// \param fill_mode The value assigned to segments which are empty | ||
SegmentMax(const Output<Node>& data, | ||
const Output<Node>& segment_ids, | ||
const Output<Node>& num_segments, | ||
const op::FillMode fill_mode); | ||
|
||
bool visit_attributes(AttributeVisitor& visitor) override; | ||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; | ||
|
||
const op::FillMode get_fill_mode() const; | ||
|
||
private: | ||
op::FillMode m_fill_mode{}; | ||
}; | ||
|
||
} // namespace ov::op::v16 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
src/core/shape_inference/include/segment_max_shape_inference.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <cmath> | ||
|
||
#include "openvino/op/segment_max.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v16 { | ||
template <class TShape, class TRShape = result_shape_t<TShape>> | ||
std::vector<TRShape> shape_infer(const SegmentMax* op, | ||
const std::vector<TShape>& input_shapes, | ||
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { | ||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 || input_shapes.size() == 3); | ||
|
||
// validate shape of data input | ||
const auto& data_shape = input_shapes[0]; | ||
const auto is_data_shape_rank_static = data_shape.rank().is_static(); | ||
if (is_data_shape_rank_static) { | ||
NODE_SHAPE_INFER_CHECK(op, input_shapes, data_shape.size() > 0, "The data input cannot be a scalar."); | ||
} | ||
|
||
// validate segment_ids input | ||
const auto& segment_ids_shape = input_shapes[1]; | ||
const auto is_segment_ids_rank_static = segment_ids_shape.rank().is_static(); | ||
if (is_segment_ids_rank_static) { | ||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
segment_ids_shape.size() == 1, | ||
"segment_ids must be a 1D input. Got: ", | ||
segment_ids_shape); | ||
if (is_data_shape_rank_static) { | ||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
data_shape[0].compatible(segment_ids_shape[0]), | ||
"The number of elements in segment_ids must match the first dimension of data."); | ||
} | ||
} | ||
const auto segment_ids = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor); | ||
if (segment_ids) { | ||
NODE_VALIDATION_CHECK(op, | ||
std::is_sorted(segment_ids->begin(), segment_ids->end()), | ||
"segment_ids must be sorted."); | ||
} | ||
|
||
// validate num_segments input | ||
const auto num_segments_available = op->inputs().size() == 3; | ||
const auto num_segments = num_segments_available ? get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor) | ||
: ov::optional<TRShape>{}; | ||
if (num_segments_available) { | ||
const auto& num_segments_shape = input_shapes[2]; | ||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
num_segments_shape.rank().compatible(0), | ||
"num_segments must be a scalar input. Got: ", | ||
num_segments_shape); | ||
} | ||
|
||
if (!is_data_shape_rank_static) { | ||
return {PartialShape::dynamic()}; | ||
} | ||
using TDim = typename TShape::value_type; | ||
auto output_shapes = std::vector<TRShape>{data_shape}; | ||
auto& output_shape = output_shapes[0]; | ||
if (num_segments) { | ||
output_shape[0] = TDim((*num_segments)[0]); | ||
} else if (segment_ids && !num_segments_available) { | ||
output_shape[0] = TDim(segment_ids->back() + 1); | ||
} else { | ||
// if num_segments input is provided but the value is unknown, the first dimension should be dynamic | ||
output_shape[0] = Dimension::dynamic(); | ||
} | ||
return output_shapes; | ||
} | ||
} // namespace v16 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/segment_max.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/validation_util.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "segment_max_shape_inference.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v16 { | ||
|
||
SegmentMax::SegmentMax(const Output<Node>& data, const Output<Node>& segment_ids, const op::FillMode fill_mode) | ||
: Op({data, segment_ids}), | ||
m_fill_mode(fill_mode) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
SegmentMax::SegmentMax(const Output<Node>& data, | ||
const Output<Node>& segment_ids, | ||
const Output<Node>& num_segments, | ||
const op::FillMode fill_mode) | ||
: Op({data, segment_ids, num_segments}), | ||
m_fill_mode(fill_mode) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
bool SegmentMax::visit_attributes(ov::AttributeVisitor& visitor) { | ||
OV_OP_SCOPE(v16_SegmentMax_visit_attributes); | ||
visitor.on_attribute("fill_mode", m_fill_mode); | ||
return true; | ||
} | ||
|
||
void SegmentMax::validate_and_infer_types() { | ||
OV_OP_SCOPE(v16_SegmentMax_validate_and_infer_types); | ||
const auto& segment_ids_element_type = get_input_element_type(1); | ||
NODE_VALIDATION_CHECK(this, | ||
segment_ids_element_type == element::i32 || segment_ids_element_type == element::i64, | ||
"The element type of the segment_ids input be i32 or i64. Got: ", | ||
segment_ids_element_type); | ||
if (inputs().size() == 3) { | ||
const auto& num_segments_element_type = get_input_element_type(2); | ||
NODE_VALIDATION_CHECK(this, | ||
num_segments_element_type == element::i32 || num_segments_element_type == element::i64, | ||
"The element type of the num_segments input be i32 or i64. Got: ", | ||
num_segments_element_type); | ||
} | ||
|
||
const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this)); | ||
set_output_type(0, get_input_element_type(0), output_shapes[0]); | ||
} | ||
|
||
std::shared_ptr<Node> SegmentMax::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
OV_OP_SCOPE(v16_SegmentMax_clone_with_new_inputs); | ||
check_new_args_count(this, new_args); | ||
if (new_args.size() == 3) { | ||
return std::make_shared<SegmentMax>(new_args.at(0), new_args.at(1), new_args.at(2), m_fill_mode); | ||
} else { | ||
return std::make_shared<SegmentMax>(new_args.at(0), new_args.at(1), m_fill_mode); | ||
} | ||
} | ||
|
||
const op::FillMode SegmentMax::get_fill_mode() const { | ||
return m_fill_mode; | ||
} | ||
|
||
} // namespace v16 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.