Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyOV] Extend Python API with SegmentMax-16 #28999

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bindings/python/src/openvino/opset16/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# New operations added in Opset16
from openvino.opset16.ops import identity
from openvino.opset16.ops import segment_max

# Operators from previous opsets
# TODO (ticket: 156877): Add previous opset operators at the end of opset16 development
32 changes: 31 additions & 1 deletion src/bindings/python/src/openvino/opset16/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

"""Factory functions for ops added to openvino opset16."""
from functools import partial
from typing import Optional
from typing import Optional, Union

from openvino import Node
from openvino.utils.decorators import nameable_op
Expand Down Expand Up @@ -32,3 +32,33 @@ def identity(
as_nodes(data, name=name),
{},
)


@nameable_op
def segment_max(
data: NodeInput,
segment_ids: NodeInput,
num_segments: Optional[NodeInput] = None,
fill_mode: Union[str, None] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fill_mode: Union[str, None] = None,
fill_mode: Optional[str] = None,

not sure why we need union here

name: Optional[str] = None,
) -> Node:
"""The SegmentMax operation finds the maximum value in each specified segment of the input tensor.

:param data: ND tensor of type T, the numerical data on which SegmentMax operation will be performed.
:param segment_ids: 1D Tensor of sorted non-negative numbers, representing the segments.
:param num_segments: An optional scalar value representing the segments count. If not provided, it is inferred from segment_ids.
:param fill_mode: Responsible for the value assigned to segments which are empty. Can be "ZERO" or "LOWEST".
:param name: Optional name for the node.

:return: The new node performing SegmentMax operation.
"""
if fill_mode is None:
raise ValueError("fill_mode must be provided and can be either 'ZERO' or 'LOWEST'")
inputs = [data, segment_ids]
if num_segments is not None:
inputs.append(num_segments)
return _get_node_factory_opset16().create(
"SegmentMax",
as_nodes(*inputs, name=name),
{"fill_mode": fill_mode},
)
53 changes: 53 additions & 0 deletions src/bindings/python/tests/test_graph/test_segment_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from openvino import Type, PartialShape, Dimension
from openvino.opset15 import parameter
import openvino.opset16 as ov
import pytest


@pytest.mark.parametrize("dtype_segment_ids", [
Type.i32,
Type.i64
])
@pytest.mark.parametrize("dtype_num_segments", [
Type.i32,
Type.i64
])
@pytest.mark.parametrize(("data_shape", "segment_ids_shape"), [
((4,), (4,)),
((1, 3, 4), (1,)),
((3, 1, 2, 5), (3,))
])
def test_segment_max_with_num_segments(dtype_segment_ids, dtype_num_segments, data_shape, segment_ids_shape):
data = parameter(data_shape, name="data", dtype=Type.f32)
segment_ids = parameter(segment_ids_shape, name="segment_ids", dtype=dtype_segment_ids)
num_segments = parameter((), name="num_segments", dtype=dtype_num_segments)
node = ov.segment_max(data, segment_ids, num_segments, fill_mode="ZERO")

assert node.get_type_name() == "SegmentMax"
assert node.get_output_size() == 1
assert node.get_output_element_type(0) == Type.f32
assert node.get_output_shape(0) == PartialShape([Dimension.dynamic(), *data_shape[1:]])


@pytest.mark.parametrize("dtype_segment_ids", [
Type.i32,
Type.i64
])
@pytest.mark.parametrize(("data_shape", "segment_ids_shape"), [
((4,), (4,)),
((1, 3, 4), (1,)),
((3, 1, 2, 5), (3,))
])
def test_segment_max_without_num_segments(dtype_segment_ids, data_shape, segment_ids_shape):
data = parameter(data_shape, name="data", dtype=Type.f32)
segment_ids = parameter(segment_ids_shape, name="segment_ids", dtype=dtype_segment_ids)
node = ov.segment_max(data, segment_ids, fill_mode="LOWEST")

assert node.get_type_name() == "SegmentMax"
assert node.get_output_size() == 1
assert node.get_output_element_type(0) == Type.f32
assert node.get_output_shape(0) == PartialShape([Dimension.dynamic(), *data_shape[1:]])
Loading