Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IT for VectorDBTool and NeuralSparseTool #177

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class RAGTool implements Tool {
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String OUTPUT_FIELD = "output_field";
public static final String QUERY_TYPE = "query_type";
public static final String CONTENT_GENERATION_FIELD = "enable_Content_Generation";
public static final String CONTENT_GENERATION_FIELD = "enable_content_generation";
public static final String K_FIELD = "k";
private final AbstractRetrieverTool queryTool;
private String name = TYPE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
assertEquals(indexName, responseInMap.get("index").toString());
}

/**
 * Sends a {@code PUT /_ingest/pipeline/{pipelineName}} request carrying the given body and
 * asserts that the parsed response reports {@code acknowledged} as {@code true}.
 *
 * @param pipelineName name appended to the {@code /_ingest/pipeline/} endpoint
 * @param body         request body sent as the pipeline definition
 * @throws Exception if issuing the request or parsing the response fails
 */
protected void createIngestPipelineWithConfiguration(String pipelineName, String body) throws Exception {
    final String endpoint = "/_ingest/pipeline/" + pipelineName;
    final Map<String, Object> parsed = parseResponseToMap(makeRequest(client(), "PUT", endpoint, null, body, null));
    final Object acknowledged = parsed.get("acknowledged");
    assertEquals("true", acknowledged.toString());
}

// Similar to deleteExternalIndices, but including indices with "." prefix vs. excluding them
protected void deleteSystemIndices() throws IOException {
final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all"));
Expand Down
515 changes: 515 additions & 0 deletions src/test/java/org/opensearch/integTest/RAGToolIT.java

Large diffs are not rendered by default.

199 changes: 199 additions & 0 deletions src/test/java/org/opensearch/integTest/VectorDBToolIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.integTest;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertThrows;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;

import lombok.SneakyThrows;

public class VectorDBToolIT extends BaseAgentToolsIT {

public static String TEST_INDEX_NAME = "test_index";

private String modelId;
private String registerAgentRequestBody;

@SneakyThrows
private void prepareModel() {
String requestBody = Files
.readString(
Path
.of(
this
.getClass()
.getClassLoader()
.getResource("org/opensearch/agent/tools/register_text_embedding_model_request_body.json")
.toURI()
)
);
modelId = registerModelThenDeploy(requestBody);
}

@SneakyThrows
private void prepareIndex() {

String pipelineConfig = "{\n"
+ " \"description\": \"text embedding pipeline\",\n"
+ " \"processors\": [\n"
+ " {\n"
+ " \"text_embedding\": {\n"
+ " \"model_id\": \""
+ modelId
+ "\",\n"
+ " \"field_map\": {\n"
+ " \"text\": \"embedding\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
createIngestPipelineWithConfiguration("test-embedding-model", pipelineConfig);

String indexMapping = "{\n"
+ " \"mappings\": {\n"
+ " \"properties\": {\n"
+ " \"text\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"embedding\": {\n"
+ " \"type\": \"knn_vector\",\n"
+ " \"dimension\": 768,\n"
+ " \"method\": {\n"
+ " \"name\": \"hnsw\",\n"
+ " \"space_type\": \"l2\",\n"
+ " \"engine\": \"lucene\",\n"
+ " \"parameters\": {\n"
+ " \"ef_construction\": 128,\n"
+ " \"m\": 24\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"settings\": {\n"
+ " \"index\": {\n"
+ " \"knn.space_type\": \"cosinesimil\",\n"
+ " \"default_pipeline\": \"test-embedding-model\",\n"
+ " \"knn\": \"true\"\n"
+ " }\n"
+ " }\n"
+ "}";

createIndexWithConfiguration(TEST_INDEX_NAME, indexMapping);

addDocToIndex(TEST_INDEX_NAME, "0", List.of("text"), List.of("hello world"));

addDocToIndex(TEST_INDEX_NAME, "1", List.of("text"), List.of("a b"));
}

@Before
@SneakyThrows
public void setUp() {
super.setUp();
prepareModel();
prepareIndex();
registerAgentRequestBody = Files
.readString(
Path
.of(
this
.getClass()
.getClassLoader()
.getResource("org/opensearch/agent/tools/register_flow_agent_of_vectordb_tool_request_body.json")
.toURI()
)
);
registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId);

}

@After
@SneakyThrows
public void tearDown() {
super.tearDown();
deleteExternalIndices();
mingshl marked this conversation as resolved.
Show resolved Hide resolved
deleteModel(modelId);
}

public void testVectorDBToolInFlowAgent() {
String agentId = createAgent(registerAgentRequestBody);

// match similar text, doc1 match with higher score
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}");

// To allow digits variation from model output, using string contains to match
assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.70467"));
assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.26499"));

// match exact same text case, doc0 match with higher score
String result1 = executeAgent(agentId, "{\"parameters\": {\"question\": \"hello\"}}");

// To allow digits variation from model output, using string contains to match
assertTrue(
result1.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.5671488")
);
assertTrue(result1.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.2423683"));

// if blank input, call onFailure and get exception
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);
}

public void testVectorDBToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");

// To allow digits variation from model output, using string contains to match
assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"0\",\"_score\":0.70493"));
assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"1\",\"_score\":0.26505"));

}

public void testVectorDBToolInFlowAgent_withIllegalEmbeddingField_thenThrowException() {
String agentId = createAgent(registerAgentRequestBody.replace("\"embedding\"", "\"embedding2\""));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("all shards failed"), containsString("SearchPhaseExecutionException"))
);
}

public void testVectorDBToolInFlowAgent_withIllegalIndexField_thenThrowException() {
String agentId = createAgent(registerAgentRequestBody.replace("test_index", "test_index2"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException"))
);
}

public void testVectorDBToolInFlowAgent_withIllegalModelIdField_thenThrowException() {
String agentId = createAgent(registerAgentRequestBody.replace(modelId, "test_model_id"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException")));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"name": "Test_Agent_For_RagTool",
"type": "flow",
"description": "this is a test flow agent in flow",
"tools": [
{
"type": "RAGTool",
"description": "A description of the tool",
"parameters": {
"embedding_model_id": "<MODEL_ID>",
"index": "<INDEX_NAME>",
"embedding_field": "embedding",
"query_type": "neural",
"enable_content_generation":"false",
"source_field": [
"text"
],
"input": "${parameters.question}",
"prompt": "\n\nHuman:You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say don't know. \n\n Context:\n${parameters.output_field}\n\nHuman:${parameters.question}\n\nAssistant:"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"name": "Test_VectorDB_Agent",
"type": "flow",
"tools": [
{
"type": "VectorDBTool",
"parameters": {
"description":"use this tool to search data from the test index",
"model_id": "<MODEL_ID>",
"index": "test_index",
"embedding_field": "embedding",
"source_field": ["text"],
"input": "${parameters.question}"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"name": "traced_small_model",
"version": "1.0.0",
"model_format": "TORCH_SCRIPT",
"model_task_type": "text_embedding",
"model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021",
"model_config": {
"model_type": "bert",
"embedding_dimension": 768,
"framework_type": "sentence_transformers",
"all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}"
},
"url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true"
}
Loading