Skip to content

Commit

Permalink
changed to smaller models and refine IT tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Feb 19, 2024
1 parent 18456c4 commit 982dce8
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class RAGTool implements Tool {
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String OUTPUT_FIELD = "output_field";
public static final String QUERY_TYPE = "query_type";
public static final String CONTENT_GENERATION_FIELD = "enable_Content_Generation";
public static final String CONTENT_GENERATION_FIELD = "enable_content_generation";
public static final String K_FIELD = "k";
private final AbstractRetrieverTool queryTool;
private String name = TYPE;
Expand Down
231 changes: 201 additions & 30 deletions src/test/java/org/opensearch/integTest/RAGToolIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,48 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.agent.tools.RAGTool;
import org.opensearch.client.ResponseException;

import lombok.SneakyThrows;

public class RAGToolIT extends BaseAgentToolsIT {
public class RAGToolIT extends ToolIntegrationTest {

public static String TEST_NEURAL_INDEX_NAME = "test_neural_index";
public static String TEST_NEURAL_SPARSE_INDEX_NAME = "test_neural_sparse_index";
private String textEmbeddingModelId;
private String sparseEncodingModelId;
private String largeLanguageModelId;
private String registerAgentWithNeuralQueryRequestBody;
private String registerAgentWithNeuralSparseQueryRequestBody;
private String registerAgentWithNeuralQueryAndLLMRequestBody;
private String mockLLMResponseWithSource = "{\n"
+ " \"inference_results\": [\n"
+ " {\n"
+ " \"output\": [\n"
+ " {\n"
+ " \"name\": \"response\",\n"
+ " \"result\": \"\"\" Based on the context given:\n"
+ " a, b, c are alphabets.\"\"\"\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ " ]\n"
+ "}";
private String mockLLMResponseWithoutSource = "{\n"
+ " \"inference_results\": [\n"
+ " {\n"
+ " \"output\": [\n"
+ " {\n"
+ " \"name\": \"response\",\n"
+ " \"result\": \"\"\" Based on the context given:\n"
+ " I do not see any information about a, b, c\". So I would have to say I don't know the answer to your question based on this context..\"\"\"\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ " ]\n"
+ "}";
private String registerAgentWithNeuralSparseQueryAndLLMRequestBody;

public RAGToolIT() throws IOException, URISyntaxException {}

Expand Down Expand Up @@ -61,6 +91,7 @@ private void prepareModel() {
)
);
sparseEncodingModelId = registerModelThenDeploy(requestBody1);
largeLanguageModelId = this.modelId;
}

@SneakyThrows
Expand Down Expand Up @@ -115,7 +146,7 @@ private void prepareIndex() {
+ " },\n"
+ " \"embedding\": {\n"
+ " \"type\": \"knn_vector\",\n"
+ " \"dimension\": 384,\n"
+ " \"dimension\": 768,\n"
+ " \"method\": {\n"
+ " \"name\": \"hnsw\",\n"
+ " \"space_type\": \"l2\",\n"
Expand Down Expand Up @@ -171,24 +202,36 @@ public void setUp() {
.replace("<MODEL_ID>", sparseEncodingModelId)
.replace("<INDEX_NAME>", TEST_NEURAL_SPARSE_INDEX_NAME)
.replace("\"query_type\": \"neural\"", "\"query_type\": \"neural_sparse\"");

registerAgentWithNeuralQueryAndLLMRequestBody = registerAgentWithNeuralQueryRequestBodyFile
.replace("<MODEL_ID>", textEmbeddingModelId + "\" ,\n \"inference_model_id\": \"" + largeLanguageModelId)
.replace("<INDEX_NAME>", TEST_NEURAL_INDEX_NAME)
.replace("false", "true");
registerAgentWithNeuralSparseQueryAndLLMRequestBody = registerAgentWithNeuralQueryRequestBodyFile
.replace("<MODEL_ID>", sparseEncodingModelId + "\" ,\n \"inference_model_id\": \"" + largeLanguageModelId)
.replace("<INDEX_NAME>", TEST_NEURAL_SPARSE_INDEX_NAME)
.replace("\"query_type\": \"neural\"", "\"query_type\": \"neural_sparse\"")
.replace("false", "true");
}

@After
@SneakyThrows
public void tearDown() {
super.tearDown();
deleteExternalIndices();
deleteModel(textEmbeddingModelId);
deleteModel(sparseEncodingModelId);
}

public void testRAGToolWithNeuralQueryInFlowAgent() {
public void testRAGToolWithNeuralQuery() {
String agentId = createAgent(registerAgentWithNeuralQueryRequestBody);

// neural query to test match similar text, doc1 match with higher score
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}");
assertEquals(
"The agent execute response not equal with expected.",
"{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.60735726}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.3785958}\n",
"{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.7046764}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.2649903}\n",
result
);

Expand All @@ -197,8 +240,8 @@ public void testRAGToolWithNeuralQueryInFlowAgent() {

assertEquals(
"The agent execute response not equal with expected.",
"{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.70875686}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.39044854}\n",
"{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.56714886}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.24236833}\n",
result1
);

Expand All @@ -208,12 +251,30 @@ public void testRAGToolWithNeuralQueryInFlowAgent() {
org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("illegal_argument_exception"))
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);

}

public void testRAGToolWithNeuralQueryAndLLM() {
String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody);

// neural query to test match similar text, doc1 match with higher score
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"use RAGTool to answer c\"}}");
assertEquals(mockLLMResponseWithSource, result);

// if blank input, call onFailure and get exception
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);

}

public void testRAGToolWithNeuralSparseQueryInFlowAgent() {
public void testRAGToolWithNeuralSparseQuery() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody);

// neural sparse query to test match extract same text, doc1 match with high score
Expand All @@ -234,11 +295,28 @@ public void testRAGToolWithNeuralSparseQueryInFlowAgent() {
org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("illegal_argument_exception"))
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);
}

public void testRAGToolWithNeuralSparseQueryAndLLM() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody);

// neural sparse query to test match extract same text, doc1 match with high score
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"use RAGTool to answer a\"}}");
assertEquals(mockLLMResponseWithSource, result);

// if blank input, call onFailure and get exception
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);
}

public void testRAGToolWithNeuralSparseQueryInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
public void testRAGToolWithNeuralSparseQuery_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
assertEquals(
Expand All @@ -248,28 +326,48 @@ public void testRAGToolWithNeuralSparseQueryInFlowAgent_withIllegalSourceField_t
);
}

public void testRAGToolWithNeuralQueryInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
assertEquals(mockLLMResponseWithoutSource, result);
}

public void testRAGToolWithNeuralQuery_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
assertEquals(
"The agent execute response not equal with expected.",
"{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"1\",\"_score\":0.7572355}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"0\",\"_score\":0.38389856}\n",
"{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"0\",\"_score\":0.70493275}\n"
+ "{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"1\",\"_score\":0.2650575}\n",
result
);
}

public void testRAGToolWithNeuralQueryAndLLM_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
assertEquals(mockLLMResponseWithoutSource, result);
}

public void testRAGToolWithNeuralSparseQuery_withIllegalEmbeddingField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace("\"embedding\"", "\"embedding2\""));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(
containsString("failed to create query: [neural_sparse] query only works on [rank_features] fields"),
containsString("search_phase_execution_exception")
)
allOf(containsString("all shards failed"), containsString("SearchPhaseExecutionException"))
);
}

public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalEmbeddingField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace("\"embedding\"", "\"embedding2\""));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("all shards failed"), containsString("SearchPhaseExecutionException"))
);
}

Expand All @@ -280,48 +378,121 @@ public void testRAGToolWithNeuralQuery_withIllegalEmbeddingField_thenThrowExcept
org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(
containsString("failed to create query: Field 'embedding2' is not knn_vector type."),
containsString("query_shard_exception")
)
allOf(containsString("all shards failed"), containsString("SearchPhaseExecutionException"))
);
}

public void testRAGToolWithNeuralQueryAndLLM_withIllegalEmbeddingField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace("\"embedding\"", "\"embedding2\""));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("all shards failed"), containsString("SearchPhaseExecutionException"))
);
}

public void testRAGToolWithNeuralSparseQueryInFlowAgent_withIllegalIndexField_thenThrowException() {
public void testRAGToolWithNeuralSparseQuery_withIllegalIndexField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace(TEST_NEURAL_SPARSE_INDEX_NAME, "test_index2"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("no such index [test_index2]"), containsString("index_not_found_exception"))
allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException"))
);
}

public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalIndexField_thenThrowException() {
String agentId = createAgent(
registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace(TEST_NEURAL_SPARSE_INDEX_NAME, "test_index2")
);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException"))
);
}

public void testRAGToolWithNeuralQueryInFlowAgent_withIllegalIndexField_thenThrowException() {
public void testRAGToolWithNeuralQuery_withIllegalIndexField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace(TEST_NEURAL_INDEX_NAME, "test_index2"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("no such index [test_index2]"), containsString("index_not_found_exception"))
allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException"))
);
}

public void testRAGToolWithNeuralQueryAndLLM_withIllegalIndexField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace(TEST_NEURAL_INDEX_NAME, "test_index2"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(
exception.getMessage(),
allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException"))
);
}

public void testRAGToolWithNeuralSparseQueryInFlowAgent_withIllegalModelIdField_thenThrowException() {
public void testRAGToolWithNeuralSparseQuery_withIllegalModelIdField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace(sparseEncodingModelId, "test_model_id"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("status_exception")));
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException")));
}

public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalModelIdField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace(sparseEncodingModelId, "test_model_id"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException")));
}

public void testRAGToolWithNeuralQueryInFlowAgent_withIllegalModelIdField_thenThrowException() {
public void testRAGToolWithNeuralQuery_withIllegalModelIdField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace(textEmbeddingModelId, "test_model_id"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("status_exception")));
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException")));
}

public void testRAGToolWithNeuralQueryAndLLM_withIllegalModelIdField_thenThrowException() {
String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace(textEmbeddingModelId, "test_model_id"));
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"));

org.hamcrest.MatcherAssert
.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException")));
}

@Override
List<PromptHandler> promptHandlers() {
PromptHandler RAGToolHandler = new PromptHandler() {
@Override
String response(String prompt) {
if (prompt.contains("RAGTool")) {
return mockLLMResponseWithSource;
} else {
return mockLLMResponseWithoutSource;
}
}

@Override
boolean apply(String prompt) {
return true;
}
};
return List.of(RAGToolHandler);
}

@Override
String toolType() {
return RAGTool.TYPE;
}
}
Loading

0 comments on commit 982dce8

Please sign in to comment.