From 48567be2b7334bc3ad6e495ad29af12fee50d7eb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 20 Feb 2024 13:33:25 +0100 Subject: [PATCH] Fix tests, add mock services for tests --- .../inference/InferenceBaseRestTest.java | 39 ++++++++--- .../xpack/inference/InferenceCrudIT.java | 12 ++-- .../MockDenseInferenceServiceIT.java | 65 +++++++++++++++++++ ...java => MockSparseInferenceServiceIT.java} | 12 ++-- .../mock/AbstractTestInferenceService.java | 6 +- .../TestDenseInferenceServiceExtension.java | 5 ++ .../TestSparseInferenceServiceExtension.java | 7 +- 7 files changed, 122 insertions(+), 24 deletions(-) create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java rename x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/{MockInferenceServiceIT.java => MockSparseInferenceServiceIT.java} (88%) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 11a5bdf045f21..a9096f9059c5b 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -50,11 +50,11 @@ protected Settings restClientSettings() { return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); } - static String mockServiceModelConfig() { - return mockServiceModelConfig(null); + static String mockSparseServiceModelConfig() { + return mockSparseServiceModelConfig(null); } - static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) { + static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) { var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" { @@ -72,7 +72,7 @@ static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) { """, taskType); } - static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) { + static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) { var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" { @@ -91,6 +91,22 @@ static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean """, taskType, shouldReturnHiddenField); } + static String mockDenseServiceModelConfig() { + return """ + { + "task_type": "text_embedding", + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_dense_vector_model", + "api_key": "abc64", + "dimensions": 246 + }, + "task_settings": { + } + } + """; + } + protected void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); @@ -200,11 +216,16 @@ private Map inferOnMockServiceInternal(String endpoint, List resultMap, int expectedNumberOfResults, TaskType taskType) { - if (taskType == TaskType.SPARSE_EMBEDDING) { - var results = (List>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString()); - assertThat(results, hasSize(expectedNumberOfResults)); - } else { - fail("test with task type [" + taskType + "] are not supported yet"); + switch (taskType) { + case SPARSE_EMBEDDING -> { + var results = (List>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString()); + assertThat(results, hasSize(expectedNumberOfResults)); + } + case TEXT_EMBEDDING -> { + var results = (List>) resultMap.get(TaskType.TEXT_EMBEDDING.toString()); + assertThat(results, hasSize(expectedNumberOfResults)); + } + default -> fail("test with task type [" + taskType + "] are not supported yet"); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f6718afd2f879..1ecc7980cea99 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -25,10 +25,10 @@ public class InferenceCrudIT extends InferenceBaseRestTest { @SuppressWarnings("unchecked") public void testGet() throws IOException { for (int i = 0; i < 5; i++) { - putModel("se_model_" + i, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); } for (int i = 0; i < 4; i++) { - putModel("te_model_" + i, mockServiceModelConfig(), TaskType.TEXT_EMBEDDING); + putModel("te_model_" + i, mockSparseServiceModelConfig(), TaskType.TEXT_EMBEDDING); } var getAllModels = (List>) getAllModels().get("models"); @@ -59,7 +59,7 @@ public void testGet() throws IOException { } public void testGetModelWithWrongTaskType() throws IOException { - putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING)); assertThat( e.getMessage(), @@ -68,7 +68,7 @@ public void testGetModelWithWrongTaskType() throws IOException { } public void testDeleteModelWithWrongTaskType() throws IOException { - putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING)); assertThat( e.getMessage(), @@ -79,7 +79,7 @@ public void testDeleteModelWithWrongTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetModelWithAnyTaskType() throws IOException { String inferenceEntityId = "sparse_embedding_model"; - putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var singleModel = (List>) getModels(inferenceEntityId, TaskType.ANY).get("models"); assertEquals(inferenceEntityId, singleModel.get(0).get("model_id")); assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type")); @@ -88,7 +88,7 @@ public void testGetModelWithAnyTaskType() throws IOException { @SuppressWarnings("unchecked") public void testApisWithoutTaskType() throws IOException { String modelId = "no_task_type_in_url"; - putModel(modelId, mockServiceModelConfig(TaskType.SPARSE_EMBEDDING)); + putModel(modelId, mockSparseServiceModelConfig(TaskType.SPARSE_EMBEDDING)); var singleModel = (List>) getModel(modelId).get("models"); assertEquals(modelId, singleModel.get(0).get("model_id")); assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type")); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java new file mode 100644 index 0000000000000..a8c0a45f3f9db --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class MockDenseInferenceServiceIT extends InferenceBaseRestTest { + + @SuppressWarnings("unchecked") + public void testMockService() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); + var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING); + var model = ((List>) getModels.get("models")).get(0); + + for (var modelMap : List.of(putModel, model)) { + assertEquals(inferenceEntityId, modelMap.get("model_id")); + assertEquals(TaskType.TEXT_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type"))); + assertEquals("text_embedding_test_service", modelMap.get("service")); + } + + // The response is randomly generated, the input can be anything + var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10))); + assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING); + } + + public void testMockServiceWithMultipleInputs() throws IOException { + String inferenceEntityId = "test-mock-with-multi-inputs"; + putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); + + // The response is randomly generated, the input can be anything + var inference = inferOnMockService( + inferenceEntityId, + TaskType.TEXT_EMBEDDING, + List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) + ); + + assertNonEmptyInferenceResults(inference, 3, TaskType.TEXT_EMBEDDING); + } + + @SuppressWarnings("unchecked") + public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); + var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING); + var model = ((List>) getModels.get("models")).get(0); + + var serviceSettings = (Map) model.get("service_settings"); + assertNull(serviceSettings.get("api_key")); + assertNotNull(serviceSettings.get("model")); + + var putServiceSettings = (Map) putModel.get("service_settings"); + assertNull(putServiceSettings.get("api_key")); + assertNotNull(putServiceSettings.get("model")); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java similarity index 88% rename from x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java rename to x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java index c226612d7a6e5..616947eae4d72 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java @@ -15,12 +15,12 @@ import static org.hamcrest.Matchers.is; -public class MockInferenceServiceIT extends InferenceBaseRestTest { +public class MockSparseInferenceServiceIT extends InferenceBaseRestTest { @SuppressWarnings("unchecked") public void testMockService() throws IOException { String inferenceEntityId = "test-mock"; - var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING); var model = ((List>) getModels.get("models")).get(0); @@ -37,7 +37,7 @@ public void testMockService() throws IOException { public void testMockServiceWithMultipleInputs() throws IOException { String inferenceEntityId = "test-mock-with-multi-inputs"; - putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); // The response is randomly generated, the input can be anything var inference = inferOnMockService( @@ -52,7 +52,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { @SuppressWarnings("unchecked") public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException { String inferenceEntityId = "test-mock"; - var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING); var model = ((List>) getModels.get("models")).get(0); @@ -68,7 +68,7 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti @SuppressWarnings("unchecked") public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws IOException { String inferenceEntityId = "test-mock"; - var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING); var model = ((List>) getModels.get("models")).get(0); @@ -87,7 +87,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I @SuppressWarnings("unchecked") public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOException { String inferenceEntityId = "test-mock"; - var putModel = putModel(inferenceEntityId, mockServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING); + var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING); var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING); var model = ((List>) getModels.get("models")).get(0); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index e2643bce5206a..99dfc9582eb05 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -56,7 +56,7 @@ public TestServiceModel parsePersistedConfigWithSecrets( var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); - var serviceSettings = TestDenseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap); + var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap); var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); var taskSettingsMap = getTaskSettingsMap(config); @@ -70,7 +70,7 @@ public TestServiceModel parsePersistedConfigWithSecrets( public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - var serviceSettings = TestDenseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap); + var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap); var taskSettingsMap = getTaskSettingsMap(config); var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); @@ -78,6 +78,8 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap); + @Override public void start(Model model, ActionListener listener) { listener.onResponse(true); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 57c6b7a85230d..54fe6e01946b4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -140,6 +140,10 @@ private List makeChunkedResults(List inp } return results; } + + protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { + return TestServiceSettings.fromMap(serviceSettingsMap); + } } public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings { @@ -214,6 +218,7 @@ public ToXContentObject getFilteredXContentObject() { return builder; }; } + } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index b2006f4c4b5f4..e5020774a70f3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -58,7 +58,7 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platfromArchitectures, + Set platformArchitectures, ActionListener parsedModelListener ) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); @@ -133,6 +133,11 @@ private List makeChunkedResults(List inp } return List.of(new ChunkedSparseEmbeddingResults(chunks)); } + + protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { + return TestServiceSettings.fromMap(serviceSettingsMap); + } + } public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {