Fix tests, add mock services for tests
carlosdelest committed Feb 20, 2024
1 parent ea48aea commit 48567be
Showing 7 changed files with 122 additions and 24 deletions.
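The core of the change is a split of the mock inference service used by the REST tests into a sparse and a dense variant: mockServiceModelConfig becomes mockSparseServiceModelConfig, a new mockDenseServiceModelConfig is added, and a new MockDenseInferenceServiceIT exercises the dense path. A rough sketch of the round trip those dense tests run — register a mock endpoint, infer, assert, delete — using the InferenceBaseRestTest helpers shown below (the class name and endpoint id in this sketch are illustrative, not part of the commit):

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;

public class MockDenseSmokeIT extends InferenceBaseRestTest {

    public void testDenseRoundTrip() throws IOException {
        String endpointId = "dense-mock-smoke"; // illustrative endpoint id
        putModel(endpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);

        // The mock service generates random embeddings, so the input text can be anything.
        var inference = inferOnMockService(endpointId, List.of(randomAlphaOfLength(10)));
        assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING);

        deleteModel(endpointId);
    }
}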
@@ -50,11 +50,11 @@ protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

static String mockServiceModelConfig() {
return mockServiceModelConfig(null);
static String mockSparseServiceModelConfig() {
return mockSparseServiceModelConfig(null);
}

static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) {
static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
{
@@ -72,7 +72,7 @@ static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) {
""", taskType);
}

static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
{
@@ -91,6 +91,22 @@ static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean
""", taskType, shouldReturnHiddenField);
}

static String mockDenseServiceModelConfig() {
return """
{
"task_type": "text_embedding",
"service": "text_embedding_test_service",
"service_settings": {
"model": "my_dense_vector_model",
"api_key": "abc64",
"dimensions": 246
},
"task_settings": {
}
}
""";
}

protected void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
@@ -200,11 +216,16 @@ private Map<String, Object> inferOnMockServiceInternal(String endpoint, List<Str

@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
if (taskType == TaskType.SPARSE_EMBEDDING) {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
} else {
fail("test with task type [" + taskType + "] are not supported yet");
switch (taskType) {
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case TEXT_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.TEXT_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
default -> fail("test with task type [" + taskType + "] are not supported yet");
}
}

@@ -25,10 +25,10 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
@SuppressWarnings("unchecked")
public void testGet() throws IOException {
for (int i = 0; i < 5; i++) {
putModel("se_model_" + i, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
}
for (int i = 0; i < 4; i++) {
putModel("te_model_" + i, mockServiceModelConfig(), TaskType.TEXT_EMBEDDING);
putModel("te_model_" + i, mockSparseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}

var getAllModels = (List<Map<String, Object>>) getAllModels().get("models");
@@ -59,7 +59,7 @@ public void testGet() throws IOException {
}

public void testGetModelWithWrongTaskType() throws IOException {
putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
assertThat(
e.getMessage(),
@@ -68,7 +68,7 @@ public void testGetModelWithWrongTaskType() throws IOException {
}

public void testDeleteModelWithWrongTaskType() throws IOException {
putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
assertThat(
e.getMessage(),
@@ -79,7 +79,7 @@ public void testDeleteModelWithWrongTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetModelWithAnyTaskType() throws IOException {
String inferenceEntityId = "sparse_embedding_model";
putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var singleModel = (List<Map<String, Object>>) getModels(inferenceEntityId, TaskType.ANY).get("models");
assertEquals(inferenceEntityId, singleModel.get(0).get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
@@ -88,7 +88,7 @@ public void testGetModelWithAnyTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testApisWithoutTaskType() throws IOException {
String modelId = "no_task_type_in_url";
putModel(modelId, mockServiceModelConfig(TaskType.SPARSE_EMBEDDING));
putModel(modelId, mockSparseServiceModelConfig(TaskType.SPARSE_EMBEDDING));
var singleModel = (List<Map<String, Object>>) getModel(modelId).get("models");
assertEquals(modelId, singleModel.get(0).get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;
import java.util.Map;

public class MockDenseInferenceServiceIT extends InferenceBaseRestTest {

@SuppressWarnings("unchecked")
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("model_id"));
assertEquals(TaskType.TEXT_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
assertEquals("text_embedding_test_service", modelMap.get("service"));
}

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING);
}

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(
inferenceEntityId,
TaskType.TEXT_EMBEDDING,
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
);

assertNonEmptyInferenceResults(inference, 3, TaskType.TEXT_EMBEDDING);
}

@SuppressWarnings("unchecked")
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

var serviceSettings = (Map<String, Object>) model.get("service_settings");
assertNull(serviceSettings.get("api_key"));
assertNotNull(serviceSettings.get("model"));

var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
assertNull(putServiceSettings.get("api_key"));
assertNotNull(putServiceSettings.get("model"));
}
}
@@ -15,12 +15,12 @@

import static org.hamcrest.Matchers.is;

public class MockInferenceServiceIT extends InferenceBaseRestTest {
public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {

@SuppressWarnings("unchecked")
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

@@ -37,7 +37,7 @@ public void testMockService() throws IOException {

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(
@@ -52,7 +52,7 @@ public void testMockServiceWithMultipleInputs() throws IOException {
@SuppressWarnings("unchecked")
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

@@ -68,7 +68,7 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti
@SuppressWarnings("unchecked")
public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

@@ -87,7 +87,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I
@SuppressWarnings("unchecked")
public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING);
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);

@@ -56,7 +56,7 @@ public TestServiceModel parsePersistedConfigWithSecrets(
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);

var serviceSettings = TestDenseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap);
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
@@ -70,14 +70,16 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

var serviceSettings = TestDenseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap);
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
}

protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);

@Override
public void start(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
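The new abstract getServiceSettingsFromMap hook above is a small template method: the shared parsePersistedConfig / parsePersistedConfigWithSecrets flow stays in the base test service, and each concrete test service only supplies its own settings parser. A minimal standalone sketch of that pattern, using illustrative stand-in types rather than the actual Elasticsearch classes:

import java.util.Map;

// Stand-ins for illustration only; not the real inference plugin types.
interface ServiceSettings {}

abstract class AbstractTestService {
    // The shared parsing flow lives in the base class...
    @SuppressWarnings("unchecked")
    ServiceSettings parseServiceSettings(Map<String, Object> config) {
        var serviceSettingsMap = (Map<String, Object>) config.remove("service_settings");
        // ...and only the concrete settings type is chosen by the subclass.
        return getServiceSettingsFromMap(serviceSettingsMap);
    }

    protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
}

class DenseTestService extends AbstractTestService {
    record DenseSettings(String model, int dimensions) implements ServiceSettings {}

    @Override
    protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> map) {
        return new DenseSettings((String) map.get("model"), ((Number) map.get("dimensions")).intValue());
    }
}

In the commit itself, the dense and sparse test services fill in this hook by delegating to their own TestServiceSettings.fromMap, as the last two diffs below show.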
@@ -140,6 +140,10 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
}
return results;
}

protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
}
}

public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings {
@@ -214,6 +218,7 @@ public ToXContentObject getFilteredXContentObject() {
return builder;
};
}

}

}
@@ -58,7 +58,7 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platfromArchitectures,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
@@ -133,6 +133,11 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
}
return List.of(new ChunkedSparseEmbeddingResults(chunks));
}

protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
}

}

public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {