Skip to content

Commit

Permalink
code style
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi committed Feb 21, 2025
1 parent 04fa015 commit 12cf4ed
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,35 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from typing import Callable
import openvino

"""Postponed Constant is a way to materialize a big constant only when it is going to be serialized to IR and then immediately dispose."""


# `maker` is a function that returns ov.Tensor that represents a target Constant
def make_postponed_constant(element_type, shape, maker):
def make_postponed_constant(element_type: openvino.Type, shape: openvino.Shape, maker: Callable[[], openvino.Tensor]) -> openvino.Op:
class PostponedConstant(openvino.Op):
class_type_info = openvino.runtime.DiscreteTypeInfo("PostponedConstant", "extension")
class_type_info = openvino.DiscreteTypeInfo("PostponedConstant", "extension")

def __init__(self):
def __init__(self) -> None:
super().__init__(self)
self.get_rt_info()["postponed_constant"] = True # value doesn't matter
self.m_element_type = element_type
self.m_shape = shape
self.constructor_validate_and_infer_types()

def get_type_info(self):
def get_type_info(self) -> openvino.DiscreteTypeInfo:
return PostponedConstant.class_type_info

def evaluate(self, outputs, _):
def evaluate(self, outputs: list[openvino.Tensor], _: list[openvino.Tensor]) -> bool:
maker().copy_to(outputs[0])
return True

def clone_with_new_inputs(self, _):
def clone_with_new_inputs(self, _: list[openvino.Tensor]) -> openvino.Op:
return PostponedConstant()

def validate_and_infer_types(self):
def validate_and_infer_types(self) -> None:
self.set_output_type(0, self.m_element_type, openvino.PartialShape(self.m_shape))

return PostponedConstant()
19 changes: 9 additions & 10 deletions src/core/tests/pass/serialization/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "common_test_utils/common_utils.hpp"
#include "common_test_utils/file_utils.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/serialize.hpp"
#include "openvino/runtime/core.hpp"
#include "openvino/op/constant.hpp"

class CustomOpsSerializationTest : public ::testing::Test {
protected:
Expand Down Expand Up @@ -117,7 +117,6 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
ASSERT_TRUE(success) << message;
}


class PostponedOp : public ov::op::Op {
public:
ov::element::Type m_type;
Expand All @@ -144,28 +143,28 @@ class PostponedOp : public ov::op::Op {
}

MOCK_METHOD(bool,
evaluate,
(ov::TensorVector & output_values, const ov::TensorVector& input_values),
(const, override));
evaluate,
(ov::TensorVector & output_values, const ov::TensorVector& input_values),
(const, override));
};

TEST(PostponedOpSerializationTest, CorrectRtInfo) {
auto constant = std::make_shared<PostponedOp>(ov::element::f16, ov::Shape{1,2,3,4});
auto constant = std::make_shared<PostponedOp>(ov::element::f16, ov::Shape{1, 2, 3, 4});
constant->get_rt_info()["postponed_constant"] = true;
auto model = std::make_shared<ov::Model>(ov::OutputVector{constant});

EXPECT_CALL(*constant, evaluate).Times(1);

std::stringstream serialized_model, serialized_weigths;
ov::pass::Serialize(serialized_model, serialized_weigths).run_on_model(model);
}

TEST(PostponedOpSerializationTest, IncorrectRtInfo) {
auto constant = std::make_shared<PostponedOp>(ov::element::f16, ov::Shape{1,2,3,4});
auto constant = std::make_shared<PostponedOp>(ov::element::f16, ov::Shape{1, 2, 3, 4});
auto model = std::make_shared<ov::Model>(ov::OutputVector{constant});

EXPECT_CALL(*constant, evaluate).Times(0);

std::stringstream serialized_model, serialized_weigths;
ov::pass::Serialize(serialized_model, serialized_weigths).run_on_model(model);
}

0 comments on commit 12cf4ed

Please sign in to comment.