From 01e021d36db8c9b11639fec822dc6b9739261cc2 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 <117025882+Mohamed-Ashraf273@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:10:07 +0200 Subject: [PATCH 1/4] [PT FE] add support of shift operations --- src/frontends/pytorch/src/op/bitwise.cpp | 36 +++++- src/frontends/pytorch/src/op_table.cpp | 6 + .../pytorch_tests/test_shiftOperations.py | 108 ++++++++++++++++++ 3 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 tests/layer_tests/pytorch_tests/test_shiftOperations.py diff --git a/src/frontends/pytorch/src/op/bitwise.cpp b/src/frontends/pytorch/src/op/bitwise.cpp index 03ae5de900ecf9..795db868219ab3 100644 --- a/src/frontends/pytorch/src/op/bitwise.cpp +++ b/src/frontends/pytorch/src/op/bitwise.cpp @@ -7,6 +7,8 @@ #include "openvino/op/bitwise_not.hpp" #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_xor.hpp" +#include "openvino/op/bitwise_left_shift.hpp" +#include "openvino/op/bitwise_right_shift.hpp" #include "openvino/op/convert_like.hpp" #include "utils.hpp" @@ -73,7 +75,39 @@ OutputVector translate_bitwise_xor(const NodeContext& context) { return {xor_x}; }; +OutputVector translate_bitwise_left_shift(const NodeContext& context) { + num_inputs_check(context, 2, 3); + Output x; + Output y; + std::tie(x, y) = get_inputs_with_promoted_types(context, 0, 1); + auto lshift = context.mark_node(std::make_shared(x, y))->output(0); + if (!context.input_is_none(2)) { + auto out = context.get_input(2); + if (out.get_element_type().is_dynamic() || lshift.get_element_type() != out.get_element_type()) { + lshift = context.mark_node(std::make_shared(lshift, out)); + } + context.mutate_input(2, lshift); + } + return {lshift}; +}; + +OutputVector translate_bitwise_right_shift(const NodeContext& context) { + num_inputs_check(context, 2, 3); + Output x; + Output y; + std::tie(x, y) = get_inputs_with_promoted_types(context, 0, 1); + auto rshift = context.mark_node(std::make_shared(x, y))->output(0); + if (!context.input_is_none(2)) { + auto out = context.get_input(2); + if (out.get_element_type().is_dynamic() || rshift.get_element_type() != out.get_element_type()) { + rshift = context.mark_node(std::make_shared(rshift, out)); + } + context.mutate_input(2, rshift); + } + return {rshift}; +} + } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 2d33b32472ba36..ba35a3967f6480 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -53,6 +53,8 @@ OP_CONVERTER(translate_bitwise_and); OP_CONVERTER(translate_bitwise_not); OP_CONVERTER(translate_bitwise_or); OP_CONVERTER(translate_bitwise_xor); +OP_CONVERTER(translate_bitwise_left_shift); +OP_CONVERTER(translate_bitwise_right_shift); OP_CONVERTER(translate_bucketize); OP_CONVERTER(translate_cat); OP_CONVERTER(translate_cdist); @@ -343,6 +345,10 @@ const std::unordered_map get_supported_ops_ts() { return { {"aten::__and__", op::translate_bitwise_and}, {"aten::__iand__", op::inplace_op}, + {"aten::__lshift__", op::translate_bitwise_left_shift}, + {"aten::__rshift__", op::translate_bitwise_right_shift}, + {"aten::bitwise_left_shift", op::translate_bitwise_left_shift}, + {"aten::bitwise_right_shift", op::translate_bitwise_right_shift}, {"aten::__derive_index", op::translate_derive_index}, {"aten::__getitem__", op::translate_getitem}, {"aten::__not__", op::translate_1to1_match_1_inputs}, diff --git a/tests/layer_tests/pytorch_tests/test_shiftOperations.py b/tests/layer_tests/pytorch_tests/test_shiftOperations.py new file mode 100644 index 00000000000000..ed1519dc0413ec --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_shiftOperations.py @@ -0,0 +1,108 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +from pytorch_layer_test_class import PytorchLayerTest, skip_if_export + + +class TestShiftOperators(PytorchLayerTest): + def _prepare_input(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape): + choices = np.array([1, 2, 4, 8, 16, 32]) + shifts = np.array([0, 1, 2, 3, 4, 5]) + + x = np.random.choice(choices, lhs_shape).astype(lhs_dtype) + y = np.random.choice(shifts, rhs_shape).astype(rhs_dtype) + return x, y + + def create_model(self): + class aten_shift(torch.nn.Module): + def forward(self, lhs, rhs): + return lhs << rhs, lhs >> rhs + + ref_net = None + return aten_shift(), ref_net, ("aten::__lshift__", "aten::__rshift__") + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize("lhs_dtype", ["int32", "int64"]) + @pytest.mark.parametrize("rhs_dtype", ["int32", "int64"]) + @pytest.mark.parametrize( + ("lhs_shape", "rhs_shape"), + [ + ([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ([], []), + ], + ) + def test_shift_operators(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape, ie_device, precision, ir_version): + self._test( + *self.create_model(), + ie_device, + precision, + ir_version, + kwargs_to_prepare_input={ + "lhs_dtype": lhs_dtype, + "rhs_dtype": rhs_dtype, + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + }, + trace_model=True, + freeze_model=False, + ) + + +class TestBitwiseShiftFunctions(PytorchLayerTest): + def _prepare_input(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape): + choices = np.array([1, 2, 4, 8, 16, 32]) + shifts = np.array([0, 1, 2, 3, 4, 5]) + + x = np.random.choice(choices, lhs_shape).astype(lhs_dtype) + y = np.random.choice(shifts, rhs_shape).astype(rhs_dtype) + return x, y + + def create_model(self): + class aten_bitwise_shift(torch.nn.Module): + def forward(self, lhs, rhs): + return ( + torch.bitwise_left_shift(lhs, rhs), + torch.bitwise_right_shift(lhs, rhs) + ) + + ref_net = None + return aten_bitwise_shift(), ref_net, ("aten::bitwise_left_shift", "aten::bitwise_right_shift") + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize("lhs_dtype", ["int32", "int64"]) + @pytest.mark.parametrize("rhs_dtype", ["int32", "int64"]) + @pytest.mark.parametrize( + ("lhs_shape", "rhs_shape"), + [ + ([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ([], []), + ], + ) + def test_bitwise_shift_functions(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape, ie_device, precision, ir_version): + self._test( + *self.create_model(), + ie_device, + precision, + ir_version, + kwargs_to_prepare_input={ + "lhs_dtype": lhs_dtype, + "rhs_dtype": rhs_dtype, + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + }, + trace_model=True, + freeze_model=False, + ) \ No newline at end of file From 03f7ea187c7fdc2e466f32dcdcad2389c87f81cb Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 <117025882+Mohamed-Ashraf273@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:28:42 +0200 Subject: [PATCH 2/4] Update op_table.cpp --- src/frontends/pytorch/src/op_table.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index ba35a3967f6480..ae5c37ddc4e98d 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -347,8 +347,6 @@ const std::unordered_map get_supported_ops_ts() { {"aten::__iand__", op::inplace_op}, {"aten::__lshift__", op::translate_bitwise_left_shift}, {"aten::__rshift__", op::translate_bitwise_right_shift}, - {"aten::bitwise_left_shift", op::translate_bitwise_left_shift}, - {"aten::bitwise_right_shift", op::translate_bitwise_right_shift}, {"aten::__derive_index", op::translate_derive_index}, {"aten::__getitem__", op::translate_getitem}, {"aten::__not__", op::translate_1to1_match_1_inputs}, @@ -417,8 +415,10 @@ const std::unordered_map get_supported_ops_ts() { {"aten::batch_norm", op::translate_batch_norm}, {"aten::bernoulli", op::translate_bernoulli}, {"aten::bitwise_and", op::translate_bitwise_and}, + {"aten::bitwise_left_shift", op::translate_bitwise_left_shift}, {"aten::bitwise_not", op::translate_bitwise_not}, {"aten::bitwise_or", op::translate_bitwise_or}, + {"aten::bitwise_right_shift", op::translate_bitwise_right_shift}, {"aten::bitwise_xor", op::translate_bitwise_xor}, {"aten::bmm", op::translate_1to1_match_2_inputs}, {"aten::Bool", op::translate_bool}, From e38f0ccebb51c62c9220313d950303e2d62d69db Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 <117025882+Mohamed-Ashraf273@users.noreply.github.com> Date: Thu, 20 Feb 2025 20:09:21 +0200 Subject: [PATCH 3/4] Update bitwise.cpp --- src/frontends/pytorch/src/op/bitwise.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/frontends/pytorch/src/op/bitwise.cpp b/src/frontends/pytorch/src/op/bitwise.cpp index 795db868219ab3..cf20a415f5f893 100644 --- a/src/frontends/pytorch/src/op/bitwise.cpp +++ b/src/frontends/pytorch/src/op/bitwise.cpp @@ -4,11 +4,11 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/bitwise_and.hpp" +#include "openvino/op/bitwise_left_shift.hpp" #include "openvino/op/bitwise_not.hpp" #include "openvino/op/bitwise_or.hpp" -#include "openvino/op/bitwise_xor.hpp" -#include "openvino/op/bitwise_left_shift.hpp" #include "openvino/op/bitwise_right_shift.hpp" +#include "openvino/op/bitwise_xor.hpp" #include "openvino/op/convert_like.hpp" #include "utils.hpp" @@ -110,4 +110,4 @@ OutputVector translate_bitwise_right_shift(const NodeContext& context) { } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov From 68eccc2e8d1d3d903be14426c366dfa9d8c8a7bb Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 <117025882+Mohamed-Ashraf273@users.noreply.github.com> Date: Fri, 21 Feb 2025 01:40:04 +0200 Subject: [PATCH 4/4] Update op_table.cpp --- src/frontends/pytorch/src/op_table.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index ae5c37ddc4e98d..ee28ea5b38f859 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -50,11 +50,11 @@ OP_CONVERTER(translate_bool); OP_CONVERTER(translate_batch_norm); OP_CONVERTER(translate_bernoulli); OP_CONVERTER(translate_bitwise_and); +OP_CONVERTER(translate_bitwise_left_shift); OP_CONVERTER(translate_bitwise_not); OP_CONVERTER(translate_bitwise_or); -OP_CONVERTER(translate_bitwise_xor); -OP_CONVERTER(translate_bitwise_left_shift); OP_CONVERTER(translate_bitwise_right_shift); +OP_CONVERTER(translate_bitwise_xor); OP_CONVERTER(translate_bucketize); OP_CONVERTER(translate_cat); OP_CONVERTER(translate_cdist);