Skip to content

Commit

Permalink
Add missing llm UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkulakowski committed Jan 8, 2025
1 parent d877288 commit 111a8c0
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,6 @@ cc_test(
"test/kfs_rest_test.cpp",
"test/kfs_rest_parser_test.cpp",
"test/layout_test.cpp",
"test/mediapipeflow_test.cpp",
"test/metric_config_test.cpp",
"test/metrics_test.cpp",
"test/metrics_flow_test.cpp",
Expand Down Expand Up @@ -1892,8 +1891,11 @@ cc_test(
"test/get_mediapipe_graph_metadata_response_test.cpp",
"test/mediapipe_framework_test.cpp",
"test/http_openai_handler_test.cpp",
"test/mediapipeflow_test.cpp",
],
"//:disable_mediapipe" : [
"test/mediapipe_disabled_test.cpp",
],
"//:disable_mediapipe" : [],
}) + select({
"//:not_disable_python": [
# OvmsPyTensor is currently not used in OVMS core and is just a base for the binding.
Expand Down
23 changes: 23 additions & 0 deletions src/test/http_openai_handler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,26 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingMessagesEmptyContentArrayFails) {
std::shared_ptr<ovms::OpenAIChatCompletionsHandler> apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer);
EXPECT_EQ(apiHandler->parseMessages(), absl::InvalidArgumentError("Invalid message structure - content array is empty"));
}

TEST_F(HttpOpenAIHandlerTest, V3ApiWithNonLLMCalculator) {
handler.reset();
server.setShutdownRequest(1);
t->join();
server.setShutdownRequest(0);
SetUpServer(getGenericFullPathForSrcTest("/ovms/src/test/mediapipe/config_mediapipe_dummy_kfs.json").c_str());
ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpoint, headers), ovms::StatusCode::OK);
std::string requestBody = R"(
{
"model": "mediapipeDummyKFS",
"stream": false,
"messages": []
}
)";

EXPECT_CALL(*writer, PartialReplyEnd()).Times(0);
EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0);
EXPECT_CALL(*writer, IsDisconnected()).Times(0);

auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer);
ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_ADD_PACKET_INPUT_STREAM);
}
18 changes: 18 additions & 0 deletions src/test/llmnode_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,24 @@ TEST_F(LLMFlowHttpTest, unaryChatCompletionsJsonN) {
EXPECT_STREQ(parsedResponse["object"].GetString(), "chat.completion");
}

TEST_F(LLMFlowHttpTest, KFSApiRequestToChatCompletionsGraph) {
std::string requestBody = R"({
"inputs" : [
{
"name" : "input",
"shape" : [ 2, 2 ],
"datatype" : "UINT32",
"data" : [ 1, 2, 3, 4 ]
}
]
})";
std::vector<std::pair<std::string, std::string>> headers;
ASSERT_EQ(handler->parseRequestComponents(comp, "POST", "/v2/models/llmDummyKFS/versions/1/infer", headers), ovms::StatusCode::OK);
ASSERT_EQ(
handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer),
ovms::StatusCode::MEDIAPIPE_GRAPH_ADD_PACKET_INPUT_STREAM);
}

TEST_F(LLMFlowHttpTest, unaryChatCompletionsJson) {
std::string requestBody = R"(
{
Expand Down
115 changes: 115 additions & 0 deletions src/test/mediapipe_disabled_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//*****************************************************************************
// Copyright 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "../http_rest_api_handler.hpp"
#include "../http_status_code.hpp"
#include "../server.hpp"
#include "ov_utils.hpp"
#include "test_http_utils.hpp"
#include "test_utils.hpp"

using namespace ovms;

class MediapipeDisabledTest : public ::testing::Test {
protected:
static std::unique_ptr<std::thread> t;

public:
std::unique_ptr<ovms::HttpRestApiHandler> handler;

std::vector<std::pair<std::string, std::string>> headers;
ovms::HttpRequestComponents comp;
const std::string endpointChatCompletions = "/v3/chat/completions";
const std::string endpointCompletions = "/v3/completions";
std::shared_ptr<MockedServerRequestInterface> writer;
ovms::HttpResponseComponents responseComponents;
std::string response;
std::vector<std::string> expectedMessages;

static void SetUpTestSuite() {
std::string port = "9173";
ovms::Server& server = ovms::Server::instance();
::SetUpServer(t, server, port, getGenericFullPathForSrcTest("/ovms/src/test/configs/config_cpu_dummy.json").c_str());
auto start = std::chrono::high_resolution_clock::now();
const int numberOfRetries = 5;
while ((server.getModuleState(ovms::SERVABLE_MANAGER_MODULE_NAME) != ovms::ModuleState::INITIALIZED) &&
(std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start).count() < numberOfRetries)) {
}
}

void SetUp() {
writer = std::make_shared<MockedServerRequestInterface>();
ON_CALL(*writer, PartialReplyBegin(::testing::_)).WillByDefault(testing::Invoke([](std::function<void()> fn) { fn(); })); // make the streaming flow sequential
ovms::Server& server = ovms::Server::instance();
handler = std::make_unique<ovms::HttpRestApiHandler>(server, 5);
ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpointChatCompletions, headers), ovms::StatusCode::OK);
}

static void TearDownTestSuite() {
ovms::Server& server = ovms::Server::instance();
server.setShutdownRequest(1);
t->join();
server.setShutdownRequest(0);
}

void TearDown() {
handler.reset();
}
};
std::unique_ptr<std::thread> MediapipeDisabledTest::t;

TEST_F(MediapipeDisabledTest, completionsRequest) {
std::string requestBody = R"(
{
"model": "dummy",
"stream": false,
"seed" : 1,
"best_of": 16,
"max_tokens": 5,
"prompt": "What is OpenVINO?"
}
)";

ASSERT_EQ(
handler->dispatchToProcessor(endpointCompletions, requestBody, &response, comp, responseComponents, writer),
ovms::StatusCode::NOT_IMPLEMENTED);
}

TEST_F(MediapipeDisabledTest, chatCompletionsRequest) {
std::string requestBody = R"(
{
"model": "dummy",
"stream": false,
"seed" : 1,
"best_of" : 16,
"n" : 8,
"max_tokens": 5,
"messages": [
{
"role": "user",
"content": "What is OpenVINO?"
}
]
}
)";

ASSERT_EQ(
handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer),
ovms::StatusCode::NOT_IMPLEMENTED);
}
31 changes: 16 additions & 15 deletions src/test/metrics_flow_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,16 +515,17 @@ TEST_F(MetricFlowTest, GrpcModelMetadata) {
request.mutable_name()->assign(dagName);
ASSERT_EQ(impl.ModelMetadata(nullptr, &request, &response).error_code(), grpc::StatusCode::OK);
}
#if (MEDIAPIPE_DISABLE == 0)
for (int i = 0; i < numberOfSuccessRequests; i++) {
request.Clear();
response.Clear();
request.mutable_name()->assign(mpName);
ASSERT_EQ(impl.ModelMetadata(nullptr, &request, &response).error_code(), grpc::StatusCode::OK);
}

checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
#endif
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
}

TEST_F(MetricFlowTest, GrpcModelReady) {
Expand All @@ -545,17 +546,17 @@ TEST_F(MetricFlowTest, GrpcModelReady) {
request.mutable_name()->assign(dagName);
ASSERT_EQ(impl.ModelReady(nullptr, &request, &response).error_code(), grpc::StatusCode::OK);
}

#if (MEDIAPIPE_DISABLE == 0)
for (int i = 0; i < numberOfSuccessRequests; i++) {
request.Clear();
response.Clear();
request.mutable_name()->assign(mpName);
ASSERT_EQ(impl.ModelReady(nullptr, &request, &response).error_code(), grpc::StatusCode::OK);
}

checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
#endif
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "gRPC", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
}

TEST_F(MetricFlowTest, RestPredict) {
Expand Down Expand Up @@ -767,16 +768,16 @@ TEST_F(MetricFlowTest, RestModelMetadata) {
std::string request, response;
ASSERT_EQ(handler.processModelMetadataKFSRequest(components, response, request), ovms::StatusCode::OK);
}

#if (MEDIAPIPE_DISABLE == 0)
for (int i = 0; i < numberOfSuccessRequests; i++) {
components.model_name = mpName;
std::string request, response;
ASSERT_EQ(handler.processModelMetadataKFSRequest(components, response, request), ovms::StatusCode::OK);
}

checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
#endif
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelMetadata", "KServe", numberOfSuccessRequests); // ran by real request
}

TEST_F(MetricFlowTest, ModelReady) {
Expand All @@ -794,16 +795,16 @@ TEST_F(MetricFlowTest, ModelReady) {
std::string request, response;
ASSERT_EQ(handler.processModelReadyKFSRequest(components, response, request), ovms::StatusCode::OK);
}

#if (MEDIAPIPE_DISABLE == 0)
for (int i = 0; i < numberOfSuccessRequests; i++) {
components.model_name = mpName;
std::string request, response;
ASSERT_EQ(handler.processModelReadyKFSRequest(components, response, request), ovms::StatusCode::OK);
}

checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkMediapipeRequestsCounterMetadataReady(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, mpName, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
#endif
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, modelName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
checkRequestsCounter(server.collect(), METRIC_NAME_REQUESTS_SUCCESS, dagName, 1, "REST", "ModelReady", "KServe", numberOfSuccessRequests); // ran by real request
}

#if (MEDIAPIPE_DISABLE == 0)
Expand Down

0 comments on commit 111a8c0

Please sign in to comment.