From 111a8c0b998ec58c77172d1d4c781fe35ef259ab Mon Sep 17 00:00:00 2001 From: michalkulakowski Date: Fri, 3 Jan 2025 08:53:54 +0100 Subject: [PATCH] Add missing llm UTs --- src/BUILD | 6 +- src/test/http_openai_handler_test.cpp | 23 ++++++ src/test/llmnode_test.cpp | 18 ++++ src/test/mediapipe_disabled_test.cpp | 115 ++++++++++++++++++++++++++ src/test/metrics_flow_test.cpp | 31 +++---- 5 files changed, 176 insertions(+), 17 deletions(-) create mode 100644 src/test/mediapipe_disabled_test.cpp diff --git a/src/BUILD b/src/BUILD index 7ab5236791..84c5af2909 100644 --- a/src/BUILD +++ b/src/BUILD @@ -1829,7 +1829,6 @@ cc_test( "test/kfs_rest_test.cpp", "test/kfs_rest_parser_test.cpp", "test/layout_test.cpp", - "test/mediapipeflow_test.cpp", "test/metric_config_test.cpp", "test/metrics_test.cpp", "test/metrics_flow_test.cpp", @@ -1892,8 +1891,11 @@ cc_test( "test/get_mediapipe_graph_metadata_response_test.cpp", "test/mediapipe_framework_test.cpp", "test/http_openai_handler_test.cpp", + "test/mediapipeflow_test.cpp", + ], + "//:disable_mediapipe" : [ + "test/mediapipe_disabled_test.cpp", ], - "//:disable_mediapipe" : [], }) + select({ "//:not_disable_python": [ # OvmsPyTensor is currently not used in OVMS core and is just a base for the binding. diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 289219722f..3d32e83ff5 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -482,3 +482,26 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingMessagesEmptyContentArrayFails) { std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer); EXPECT_EQ(apiHandler->parseMessages(), absl::InvalidArgumentError("Invalid message structure - content array is empty")); } + +TEST_F(HttpOpenAIHandlerTest, V3ApiWithNonLLMCalculator) { + handler.reset(); + server.setShutdownRequest(1); + t->join(); + server.setShutdownRequest(0); + SetUpServer(getGenericFullPathForSrcTest("/ovms/src/test/mediapipe/config_mediapipe_dummy_kfs.json").c_str()); + ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpoint, headers), ovms::StatusCode::OK); + std::string requestBody = R"( + { + "model": "mediapipeDummyKFS", + "stream": false, + "messages": [] + } + )"; + + EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); + EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); + EXPECT_CALL(*writer, IsDisconnected()).Times(0); + + auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer); + ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_ADD_PACKET_INPUT_STREAM); + } diff --git a/src/test/llmnode_test.cpp b/src/test/llmnode_test.cpp index a66023a572..54b3f95e28 100644 --- a/src/test/llmnode_test.cpp +++ b/src/test/llmnode_test.cpp @@ -589,6 +589,24 @@ TEST_F(LLMFlowHttpTest, unaryChatCompletionsJsonN) { EXPECT_STREQ(parsedResponse["object"].GetString(), "chat.completion"); } +TEST_F(LLMFlowHttpTest, KFSApiRequestToChatCompletionsGraph) { + std::string requestBody = R"({ + "inputs" : [ + { + "name" : "input", + "shape" : [ 2, 2 ], + "datatype" : "UINT32", + "data" : [ 1, 2, 3, 4 ] + } + ] + })"; + std::vector> headers; + ASSERT_EQ(handler->parseRequestComponents(comp, "POST", "/v2/models/llmDummyKFS/versions/1/infer", headers), ovms::StatusCode::OK); + ASSERT_EQ( + handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer), + ovms::StatusCode::MEDIAPIPE_GRAPH_ADD_PACKET_INPUT_STREAM); +} + TEST_F(LLMFlowHttpTest, unaryChatCompletionsJson) { std::string requestBody = R"( { diff --git a/src/test/mediapipe_disabled_test.cpp b/src/test/mediapipe_disabled_test.cpp new file mode 100644 index 0000000000..9d616953ba --- /dev/null +++ b/src/test/mediapipe_disabled_test.cpp @@ -0,0 +1,115 @@ +//***************************************************************************** +// Copyright 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "../http_rest_api_handler.hpp" +#include "../http_status_code.hpp" +#include "../server.hpp" +#include "ov_utils.hpp" +#include "test_http_utils.hpp" +#include "test_utils.hpp" + +using namespace ovms; + +class MediapipeDisabledTest : public ::testing::Test { +protected: + static std::unique_ptr t; + +public: + std::unique_ptr handler; + + std::vector> headers; + ovms::HttpRequestComponents comp; + const std::string endpointChatCompletions = "/v3/chat/completions"; + const std::string endpointCompletions = "/v3/completions"; + std::shared_ptr writer; + ovms::HttpResponseComponents responseComponents; + std::string response; + std::vector expectedMessages; + + static void SetUpTestSuite() { + std::string port = "9173"; + ovms::Server& server = ovms::Server::instance(); + ::SetUpServer(t, server, port, getGenericFullPathForSrcTest("/ovms/src/test/configs/config_cpu_dummy.json").c_str()); + auto start = std::chrono::high_resolution_clock::now(); + const int numberOfRetries = 5; + while ((server.getModuleState(ovms::SERVABLE_MANAGER_MODULE_NAME) != ovms::ModuleState::INITIALIZED) && + (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start).count() < numberOfRetries)) { + } + } + + void SetUp() { + writer = std::make_shared(); + ON_CALL(*writer, PartialReplyBegin(::testing::_)).WillByDefault(testing::Invoke([](std::function fn) { fn(); })); // make the streaming flow sequential + ovms::Server& server = ovms::Server::instance(); + handler = std::make_unique(server, 5); + ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpointChatCompletions, headers), ovms::StatusCode::OK); + } + + static void TearDownTestSuite() { + ovms::Server& server = ovms::Server::instance(); + server.setShutdownRequest(1); + t->join(); + server.setShutdownRequest(0); + } + + void TearDown() { + handler.reset(); + } +}; +std::unique_ptr MediapipeDisabledTest::t; + +TEST_F(MediapipeDisabledTest, completionsRequest) { + std::string requestBody = R"( + { + "model": "dummy", + "stream": false, + "seed" : 1, + "best_of": 16, + "max_tokens": 5, + "prompt": "What is OpenVINO?" + } + )"; + + ASSERT_EQ( + handler->dispatchToProcessor(endpointCompletions, requestBody, &response, comp, responseComponents, writer), + ovms::StatusCode::NOT_IMPLEMENTED); +} + +TEST_F(MediapipeDisabledTest, chatCompletionsRequest) { + std::string requestBody = R"( + { + "model": "dummy", + "stream": false, + "seed" : 1, + "best_of" : 16, + "n" : 8, + "max_tokens": 5, + "messages": [ + { + "role": "user", + "content": "What is OpenVINO?" + } + ] + } + )"; + + ASSERT_EQ( + handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer), + ovms::StatusCode::NOT_IMPLEMENTED); +} diff --git a/src/test/metrics_flow_test.cpp b/src/test/metrics_flow_test.cpp index ae3d6e4d6b..c501210e8e 100644 --- a/src/test/metrics_flow_test.cpp +++ b/src/test/metrics_flow_test.cpp @@ -515,16 +515,17 @@ TEST_F(MetricFlowTest, GrpcModelMetadata) { request.mutable_name()->assign(dagName); ASSERT_EQ(impl.ModelMetadata(nullptr, &request, &response).error_code(), grpc::StatusCode::OK); } +#if (MEDIAPIPE_DISABLE == 0) for (int i = 0; i < numberOfSuccessRequests; i++) { request.Clear(); response.Clear(); request.mutable_name()->assign(mpName); ASSERT_EQ(impl.ModelMetadata(nullptr, &request, &response).error_code(), grpc::StatusCode::OK); } - - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request +#endif + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request } TEST_F(MetricFlowTest, GrpcModelReady) { @@ -545,17 +546,17 @@ TEST_F(MetricFlowTest, GrpcModelReady) { request.mutable_name()->assign(dagName); ASSERT_EQ(impl.ModelReady(nullptr, &request, &response).error_code(), grpc::StatusCode::OK); } - +#if (MEDIAPIPE_DISABLE == 0) for (int i = 0; i < numberOfSuccessRequests; i++) { request.Clear(); response.Clear(); request.mutable_name()->assign(mpName); ASSERT_EQ(impl.ModelReady(nullptr, &request, &response).error_code(), grpc::StatusCode::OK); } - - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request +#endif + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request } TEST_F(MetricFlowTest, RestPredict) { @@ -767,16 +768,16 @@ TEST_F(MetricFlowTest, RestModelMetadata) { std::string request, response; ASSERT_EQ(handler.processModelMetadataKFSRequest(components, response, request), ovms::StatusCode::OK); } - +#if (MEDIAPIPE_DISABLE == 0) for (int i = 0; i < numberOfSuccessRequests; i++) { components.model_name = mpName; std::string request, response; ASSERT_EQ(handler.processModelMetadataKFSRequest(components, response, request), ovms::StatusCode::OK); } - - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request +#endif + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request } TEST_F(MetricFlowTest, ModelReady) { @@ -794,16 +795,16 @@ TEST_F(MetricFlowTest, ModelReady) { std::string request, response; ASSERT_EQ(handler.processModelReadyKFSRequest(components, response, request), ovms::StatusCode::OK); } - +#if (MEDIAPIPE_DISABLE == 0) for (int i = 0; i < numberOfSuccessRequests; i++) { components.model_name = mpName; std::string request, response; ASSERT_EQ(handler.processModelReadyKFSRequest(components, response, request), ovms::StatusCode::OK); } - - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request - checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request +#endif + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request + checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request } #if (MEDIAPIPE_DISABLE == 0)