diff --git a/plugin/build.gradle b/plugin/build.gradle index d9f97cf4cf..bca23cf2e5 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -578,3 +578,8 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) { dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask") dependsOn tasks.named("${baseName}#fullRestartClusterTask") } + +forbiddenPatterns { + exclude '**/*.pdf' + exclude '**/*.jpg' +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index a7e3b9932a..e8154f4c2d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -17,19 +17,24 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector; +import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.deployRemoteModel; import static org.opensearch.ml.utils.TestHelper.makeRequest; import static org.opensearch.ml.utils.TestHelper.toHttpEntity; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Base64; import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.commons.io.FileUtils; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLTaskState; @@ -39,7 +44,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { +public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); private static final String 
OPENAI_CONNECTOR_BLUEPRINT = "{\n" @@ -70,11 +75,42 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " ]\n" + "}"; + private static final String OPENAI_4o_CONNECTOR_BLUEPRINT = "{\n" + + " \"name\": \"OpenAI Chat Connector\",\n" + + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + " \"version\": 2,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"model\": \"gpt-4o-mini\",\n" + + " \"temperature\": 0\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/chat/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages}, \\\"temperature\\\": ${parameters.temperature} , \\\"max_tokens\\\": 300 }\"\n" + + " }\n" + + " ]\n" + + "}"; + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + private static final String BEDROCK_ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n" + " \"name\": \"Bedrock Connector: claude2\",\n" + " \"description\": \"The connector to bedrock claude2 model\",\n" @@ -145,10 +181,100 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " ]\n" + "}"; + 
private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" + + " \"name\": \"Bedrock Connector: claude 3.5\",\n" + + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model\": \"" + + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET + + "\",\n" + + " \"system_prompt\": \"You are a helpful assistant.\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/" + + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET + + "/converse\",\n" + + " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\": ${parameters.messages} , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" + + " }\n" + + " ]\n" + + "}"; + + private static final String BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" + + " \"name\": \"Bedrock Connector: claude 3\",\n" + + " \"description\": \"The connector to bedrock claude 3 model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model\": \"" + + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET + + "\",\n" + + " \"system_prompt\": \"You are a helpful assistant.\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + 
AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/" + + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET + + "/converse\",\n" + + " \"request_body\": \"{ \\\"messages\\\": ${parameters.messages} , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" + + " }\n" + + " ]\n" + + "}"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null ? BEDROCK_CONNECTOR_BLUEPRINT2 : BEDROCK_CONNECTOR_BLUEPRINT1; + // NOTE(review): the original ternary had identical branches (no session-token-free converse blueprint exists yet); assign directly. + private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT = BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2; + private static final String COHERE_KEY = System.getenv("COHERE_KEY"); private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n" + " \"name\": \"Cohere Chat Model\",\n" @@ -192,6 +318,22 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " ]\n" + "}"; + // In some cases, we do not want a system prompt to be sent to an LLM. 
+ private static final String PIPELINE_TEMPLATE2 = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"retrieval_augmented_generation\": {\n" + + " \"tag\": \"%s\",\n" + + " \"description\": \"%s\",\n" + + " \"model_id\": \"%s\",\n" + // + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_field_list\": [\"%s\"]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + private static final String BM25_SEARCH_REQUEST_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -210,6 +352,63 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " }\n" + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_IMAGE_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"message_size\": %d,\n" + + " \"timeout\": %d,\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" + + " }\n" + + " }\n" + + "}"; + + private static final String BM25_SEARCH_REQUEST_WITH_DOCUMENT_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + // + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"message_size\": %d,\n" + + " \"timeout\": %d,\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" + + " }\n" + + " }\n" + + "}"; + + 
private static final String BM25_SEARCH_REQUEST_WITH_IMAGE_AND_DOCUMENT_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"message_size\": %d,\n" + + " \"timeout\": %d,\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}} , {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" + + " }\n" + + " }\n" + + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -229,6 +428,26 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " }\n" + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"memory_id\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"message_size\": %d,\n" + + " \"timeout\": %d,\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" + + " }\n" + + " }\n" + + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -247,18 +466,28 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + "}"; private static final String OPENAI_MODEL = "gpt-3.5-turbo"; 
+ private static final String OPENAI_40_MODEL = "gpt-4o-mini"; private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude"; + private static final String BEDROCK_CONVERSE_ANTHROPIC_CLAUDE = "bedrock-converse/" + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET; + private static final String BEDROCK_CONVERSE_ANTHROPIC_CLAUDE_3 = "bedrock-converse/" + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET; private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/"; private static Set testDocs = Set.of("qa_doc1.json", "qa_doc2.json", "qa_doc3.json"); private static final String DEFAULT_USER_AGENT = "Kibana"; protected ClassLoader classLoader = RestMLRAGSearchProcessorIT.class.getClassLoader(); private static final String INDEX_NAME = "test"; + private static final String ML_RAG_REMOTE_MODEL_GROUP = "rag_remote_model_group"; + // "client" gets initialized by the test framework at the instance level // so we perform this per test case, not via @BeforeClass. @Before public void init() throws Exception { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + // TODO Do we really need to wait this long? This adds 20s to every test case run. + // Can we instead check the cluster state and move on? 
+ Thread.sleep(20000); + Response response = TestHelper .makeRequest( client(), @@ -307,11 +536,11 @@ public void testBM25WithOpenAI() throws Exception { Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -353,6 +582,94 @@ public void testBM25WithOpenAI() throws Exception { assertNotNull(answer); } + @Ignore + public void testBM25WithOpenAIWithImage() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4o-mini completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + + PipelineParameters pipelineParameters = new 
PipelineParameters(); + pipelineParameters.tag = "testBM25WithOpenAIWithImage"; + pipelineParameters.description = "desc"; + pipelineParameters.modelId = modelId; + pipelineParameters.systemPrompt = "You are a helpful assistant"; + pipelineParameters.userInstructions = "none"; + pipelineParameters.context_field = "text"; + Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + byte[] rawImage = FileUtils + .readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile()); + String imageContent = Base64.getEncoder().encodeToString(rawImage); + + SearchRequestParameters requestParameters = new SearchRequestParameters(); + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = OPENAI_40_MODEL; + requestParameters.llmQuestion = "what is this image"; + requestParameters.systemPrompt = "You are great at answering questions"; + requestParameters.userInstructions = "Follow my instructions as best you can"; + requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + requestParameters.imageFormat = "jpeg"; + requestParameters.imageType = "data"; + requestParameters.imageData = imageContent; + Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response2.getStatusLine().getStatusCode()); + + Map responseMap2 = parseResponseToMap(response2); + Map ext = (Map) responseMap2.get("ext"); + assertNotNull(ext); + Map rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + // TODO handle errors such as throttling + String answer = (String) rag.get("answer"); + assertNotNull(answer); + + requestParameters = new SearchRequestParameters(); + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = OPENAI_40_MODEL; + 
requestParameters.llmQuestion = "what is this image"; + requestParameters.systemPrompt = "You are great at answering questions"; + requestParameters.userInstructions = "Follow my instructions as best you can"; + requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + requestParameters.imageFormat = "jpeg"; + requestParameters.imageType = "url"; + requestParameters.imageData = + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; + Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response3.getStatusLine().getStatusCode()); + + Map responseMap3 = parseResponseToMap(response3); + ext = (Map) responseMap3.get("ext"); + assertNotNull(ext); + rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + answer = (String) rag.get("answer"); + assertNotNull(answer); + } + public void testBM25WithBedrock() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { @@ -361,11 +678,11 @@ public void testBM25WithBedrock() throws Exception { Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -374,7 +691,7 @@ public void 
testBM25WithBedrock() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithOpenAI"; + pipelineParameters.tag = "testBM25WithBedrock"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -405,6 +722,180 @@ public void testBM25WithBedrock() throws Exception { assertNotNull(answer); } + @Ignore + public void testBM25WithBedrockConverse() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null) { + return; + } + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + + PipelineParameters pipelineParameters = new PipelineParameters(); + pipelineParameters.tag = "testBM25WithBedrockConverse"; + pipelineParameters.description = "desc"; + pipelineParameters.modelId = modelId; + pipelineParameters.systemPrompt = "You are a helpful assistant"; + pipelineParameters.userInstructions = "none"; + pipelineParameters.context_field = "text"; + Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + SearchRequestParameters 
requestParameters = new SearchRequestParameters(); + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE; + requestParameters.llmQuestion = "who is lincoln"; + requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response2.getStatusLine().getStatusCode()); + + Map responseMap2 = parseResponseToMap(response2); + Map ext = (Map) responseMap2.get("ext"); + assertNotNull(ext); + Map rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + // TODO handle errors such as throttling + String answer = (String) rag.get("answer"); + assertNotNull(answer); + } + + @Ignore + public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null) { + return; + } + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + + PipelineParameters pipelineParameters = new PipelineParameters(); + pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessages"; + pipelineParameters.description = "desc"; + 
pipelineParameters.modelId = modelId; + pipelineParameters.systemPrompt = "You are a helpful assistant"; + pipelineParameters.userInstructions = "none"; + pipelineParameters.context_field = "text"; + Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + byte[] rawImage = FileUtils + .readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile()); + String imageContent = Base64.getEncoder().encodeToString(rawImage); + + SearchRequestParameters requestParameters = new SearchRequestParameters(); + + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE; + requestParameters.llmQuestion = "describe the image and answer the question: would lincoln have liked this place"; + requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + requestParameters.imageFormat = "jpeg"; + requestParameters.imageType = "data"; // Bedrock does not support URLs + requestParameters.imageData = imageContent; + Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response2.getStatusLine().getStatusCode()); + + Map responseMap2 = parseResponseToMap(response2); + Map ext = (Map) responseMap2.get("ext"); + assertNotNull(ext); + Map rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + // TODO handle errors such as throttling + String answer = (String) rag.get("answer"); + assertNotNull(answer); + } + + @Ignore + public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null) { + return; + } + Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2); + Map responseMap = parseResponseToMap(response); + String connectorId = 
(String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + + PipelineParameters pipelineParameters = new PipelineParameters(); + pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat"; + pipelineParameters.description = "desc"; + pipelineParameters.modelId = modelId; + // pipelineParameters.systemPrompt = "You are a helpful assistant"; + pipelineParameters.userInstructions = "none"; + pipelineParameters.context_field = "text"; + Response response1 = createSearchPipeline2("pipeline_test", pipelineParameters); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + byte[] docBytes = FileUtils.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "lincoln.pdf").toURI()).toFile()); + String docContent = Base64.getEncoder().encodeToString(docBytes); + + SearchRequestParameters requestParameters; + requestParameters = new SearchRequestParameters(); + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE_3; + requestParameters.llmQuestion = "use the information from the attached document to tell me something interesting about lincoln"; + requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + requestParameters.documentFormat = "pdf"; + requestParameters.documentName = "lincoln"; + 
requestParameters.documentData = docContent; + Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response3.getStatusLine().getStatusCode()); + + Map responseMap3 = parseResponseToMap(response3); + Map ext = (Map) responseMap3.get("ext"); + assertNotNull(ext); + Map rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + // TODO handle errors such as throttling + String answer = (String) rag.get("answer"); + assertNotNull(answer); + } + public void testBM25WithOpenAIWithConversation() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { @@ -413,11 +904,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception { Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -426,7 +917,7 @@ public void testBM25WithOpenAIWithConversation() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithOpenAI"; + pipelineParameters.tag = "testBM25WithOpenAIWithConversation"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -462,6 +953,68 @@ public void 
testBM25WithOpenAIWithConversation() throws Exception { assertNotNull(interactionId); } + @Ignore + public void testBM25WithOpenAIWithConversationAndImage() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + + PipelineParameters pipelineParameters = new PipelineParameters(); + pipelineParameters.tag = "testBM25WithOpenAIWithConversationAndImage"; + pipelineParameters.description = "desc"; + pipelineParameters.modelId = modelId; + pipelineParameters.systemPrompt = "You are a helpful assistant"; + pipelineParameters.userInstructions = "none"; + pipelineParameters.context_field = "text"; + Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + String conversationId = createConversation("test_convo_1"); + SearchRequestParameters requestParameters = new SearchRequestParameters(); + requestParameters.source = "text"; + requestParameters.match = "president"; + requestParameters.llmModel = OPENAI_40_MODEL; + requestParameters.llmQuestion = "describe the image and answer the question: can you picture lincoln enjoying himself there"; + 
requestParameters.contextSize = 5; + requestParameters.interactionSize = 5; + requestParameters.timeout = 60; + requestParameters.conversationId = conversationId; + requestParameters.imageFormat = "jpeg"; + requestParameters.imageType = "url"; + requestParameters.imageData = + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; + Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); + assertEquals(200, response2.getStatusLine().getStatusCode()); + + Map responseMap2 = parseResponseToMap(response2); + Map ext = (Map) responseMap2.get("ext"); + assertNotNull(ext); + Map rag = (Map) ext.get("retrieval_augmented_generation"); + assertNotNull(rag); + + // TODO handle errors such as throttling + String answer = (String) rag.get("answer"); + assertNotNull(answer); + + String interactionId = (String) rag.get("message_id"); + assertNotNull(interactionId); + } + public void testBM25WithBedrockWithConversation() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { @@ -470,11 +1023,11 @@ public void testBM25WithBedrockWithConversation() throws Exception { Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("Bedrock", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -483,7 +1036,7 @@ public void 
testBM25WithBedrockWithConversation() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrock"; + pipelineParameters.tag = "testBM25WithBedrockWithConversation"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -527,11 +1080,11 @@ public void testBM25WithCohere() throws Exception { Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("Cohere Chat Completion v1", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -579,11 +1132,11 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception { Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("Cohere Chat Completion v1", connectorId); + response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = RestMLRemoteInferenceIT.getTask(taskId); responseMap = parseResponseToMap(response); String 
modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -592,7 +1145,7 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithCohereLlmResponseField"; + pipelineParameters.tag = "testBM25WithCohereUsingLlmResponseField"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -647,9 +1200,33 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame ); } + // No system prompt + private Response createSearchPipeline2(String pipeline, PipelineParameters parameters) throws Exception { + return makeRequest( + client(), + "PUT", + String.format(Locale.ROOT, "/_search/pipeline/%s", pipeline), + null, + toHttpEntity( + String + .format( + Locale.ROOT, + PIPELINE_TEMPLATE2, + parameters.tag, + parameters.description, + parameters.modelId, + parameters.userInstructions, + parameters.context_field + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters) throws Exception { + // TODO build these templates dynamically String httpEntity = requestParameters.llmResponseField != null ? String .format( @@ -665,6 +1242,90 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.timeout, requestParameters.llmResponseField ) + : (requestParameters.documentData != null && requestParameters.imageType != null) + ? 
String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_IMAGE_AND_DOCUMENT_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.systemPrompt, + requestParameters.userInstructions, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout, + requestParameters.llmQuestion, + requestParameters.imageFormat, + requestParameters.imageType, + requestParameters.imageData, + requestParameters.documentFormat, + requestParameters.documentName, + requestParameters.documentData + ) + : (requestParameters.documentData != null) + ? String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_DOCUMENT_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + // requestParameters.systemPrompt, + requestParameters.userInstructions, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout, + requestParameters.llmQuestion, + requestParameters.documentFormat, + requestParameters.documentName, + requestParameters.documentData + ) + : (requestParameters.conversationId != null && requestParameters.imageType != null) + ? String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.conversationId, + requestParameters.systemPrompt, + requestParameters.userInstructions, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout, + requestParameters.llmQuestion, + requestParameters.imageFormat, + requestParameters.imageType, + requestParameters.imageData + ) + : (requestParameters.imageType != null) + ? 
String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_IMAGE_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.systemPrompt, + requestParameters.userInstructions, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout, + requestParameters.llmQuestion, + requestParameters.imageFormat, + requestParameters.imageType, + requestParameters.imageData + ) : (requestParameters.conversationId == null) ? String .format( @@ -741,5 +1402,11 @@ static class SearchRequestParameters { String conversationId; String llmResponseField; + String imageFormat; + String imageType; + String imageData; + String documentFormat; + String documentName; + String documentData; } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf new file mode 100644 index 0000000000..16eddb91fd Binary files /dev/null and b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf differ diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg new file mode 100644 index 0000000000..19fa158886 Binary files /dev/null and b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg differ diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 7b1814c2a5..6e8a5544b3 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ 
b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -179,7 +179,8 @@ public void processResponseAsync( chatHistory, searchResults, timeout, - params.getLlmResponseField() + params.getLlmResponseField(), + params.getLlmMessages() ), null, llmQuestion, @@ -202,7 +203,8 @@ public void processResponseAsync( chatHistory, searchResults, timeout, - params.getLlmResponseField() + params.getLlmResponseField(), + params.getLlmMessages() ), conversationId, llmQuestion, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 01dc97db75..ba4f1c9b03 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -18,6 +18,8 @@ package org.opensearch.searchpipelines.questionanswering.generative.ext; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; import org.opensearch.core.ParseField; @@ -30,6 +32,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import com.google.common.base.Preconditions; @@ -81,6 +84,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { // that contains the chat completion text, i.e. "answer". 
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field"); + private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages"); + public static final int SIZE_NULL_VALUE = -1; static { @@ -94,6 +99,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD); + PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD); } @Setter @@ -132,6 +138,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private String llmResponseField; + @Setter + @Getter + private List llmMessages = new ArrayList<>(); + public GenerativeQAParameters( String conversationId, String llmModel, @@ -142,6 +152,32 @@ public GenerativeQAParameters( Integer interactionSize, Integer timeout, String llmResponseField + ) { + this( + conversationId, + llmModel, + llmQuestion, + systemPrompt, + userInstructions, + contextSize, + interactionSize, + timeout, + llmResponseField, + null + ); + } + + public GenerativeQAParameters( + String conversationId, + String llmModel, + String llmQuestion, + String systemPrompt, + String userInstructions, + Integer contextSize, + Integer interactionSize, + Integer timeout, + String llmResponseField, + List llmMessages ) { this.conversationId = conversationId; this.llmModel = llmModel; @@ -156,6 +192,9 @@ public GenerativeQAParameters( this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; this.timeout = (timeout == null) ? 
SIZE_NULL_VALUE : timeout; this.llmResponseField = llmResponseField; + if (llmMessages != null) { + this.llmMessages.addAll(llmMessages); + } } public GenerativeQAParameters(StreamInput input) throws IOException { @@ -168,6 +207,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.interactionSize = input.readInt(); this.timeout = input.readInt(); this.llmResponseField = input.readOptionalString(); + this.llmMessages.addAll(input.readList(MessageBlock::new)); } @Override @@ -181,7 +221,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) .field(TIMEOUT.getPreferredName(), this.timeout) - .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField); + .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField) + .field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages); } @Override @@ -197,6 +238,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(interactionSize); out.writeInt(timeout); out.writeOptionalString(llmResponseField); + out.writeList(llmMessages); } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { @@ -223,4 +265,8 @@ public boolean equals(Object o) { && (this.timeout == other.getTimeout()) && Objects.equals(this.llmResponseField, other.getLlmResponseField()); } + + public void setMessageBlock(List blockList) { + this.llmMessages = blockList; + } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index 66c635b211..3202d56455 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ 
b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -44,4 +44,5 @@ public class ChatCompletionInput { private String userInstructions; private Llm.ModelProvider modelProvider; private String llmResponseField; + private List llmMessages; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index f6cdfec816..6793253480 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -75,7 +75,6 @@ protected void setMlClient(MachineLearningInternalClient mlClient) { * @return */ @Override - public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener listener) { MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); @@ -113,14 +112,15 @@ protected Map getInputParameters(ChatCompletionInput chatComplet inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); String messages = PromptUtil .getChatCompletionPrompt( + chatCompletionInput.getModelProvider(), chatCompletionInput.getSystemPrompt(), chatCompletionInput.getUserInstructions(), chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts() + chatCompletionInput.getContexts(), + chatCompletionInput.getLlmMessages() ); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); - // log.info("Messages to LLM: {}", messages); } else if (chatCompletionInput.getModelProvider() == 
ModelProvider.BEDROCK || chatCompletionInput.getModelProvider() == ModelProvider.COHERE || chatCompletionInput.getLlmResponseField() != null) { @@ -136,6 +136,19 @@ protected Map getInputParameters(ChatCompletionInput chatComplet chatCompletionInput.getContexts() ) ); + } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) { + // Bedrock Converse API does not include the system prompt as part of the Messages block. + String messages = PromptUtil + .getChatCompletionPrompt( + chatCompletionInput.getModelProvider(), + null, + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts(), + chatCompletionInput.getLlmMessages() + ); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); } else { throw new IllegalArgumentException( "Unknown/unsupported model provider: " @@ -144,7 +157,6 @@ protected Map getInputParameters(ChatCompletionInput chatComplet ); } - // log.info("LLM input parameters: {}", inputParameters.toString()); return inputParameters; } @@ -184,6 +196,20 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, } else if (provider == ModelProvider.COHERE) { answerField = "text"; fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField); + } else if (provider == ModelProvider.BEDROCK_CONVERSE) { + Map output = (Map) dataAsMap.get("output"); + Map message = (Map) output.get("message"); + if (message != null) { + List content = (List) message.get("content"); + String answer = (String) ((Map) content.get(0)).get("text"); + answers.add(answer); + } else { + Map error = (Map) output.get("error"); + if (error == null) { + throw new RuntimeException("Unexpected output: " + output); + } + errors.add((String) error.get("message")); + } } else { throw new IllegalArgumentException( "Unknown/unsupported model provider: " + provider + ". 
You must provide a valid model provider or llm_response_field." diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index 1099b1e21f..9318b681d2 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -28,7 +28,8 @@ public interface Llm { enum ModelProvider { OPENAI, BEDROCK, - COHERE + COHERE, + BEDROCK_CONVERSE } void doChatCompletion(ChatCompletionInput input, ActionListener listener); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index ef9e9948db..24e38ac368 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -29,6 +29,7 @@ public class LlmIOUtil { public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/"; public static final String COHERE_PROVIDER_PREFIX = "cohere/"; + public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/"; public static ChatCompletionInput createChatCompletionInput( String llmModel, @@ -49,7 +50,8 @@ public static ChatCompletionInput createChatCompletionInput( chatHistory, contexts, timeoutInSeconds, - llmResponseField + llmResponseField, + null ); } @@ -61,7 +63,8 @@ public static ChatCompletionInput createChatCompletionInput( List chatHistory, List contexts, int timeoutInSeconds, - String llmResponseField + String llmResponseField, + List llmMessages ) { Llm.ModelProvider provider = null; 
if (llmResponseField == null) { @@ -71,6 +74,8 @@ public static ChatCompletionInput createChatCompletionInput( provider = Llm.ModelProvider.BEDROCK; } else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) { provider = Llm.ModelProvider.COHERE; + } else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) { + provider = Llm.ModelProvider.BEDROCK_CONVERSE; } } } @@ -83,7 +88,8 @@ public static ChatCompletionInput createChatCompletionInput( systemPrompt, userInstructions, provider, - llmResponseField + llmResponseField, + llmMessages ); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java new file mode 100644 index 0000000000..1dbfd4d13b --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java @@ -0,0 +1,325 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; + +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.Setter; + +public class MessageBlock implements Writeable, ToXContent { + + private static final String TEXT_BLOCK = "text"; + private static final String IMAGE_BLOCK = "image"; + private static final String DOCUMENT_BLOCK = "document"; + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.role); + out.writeList(this.blockList); + } + + public MessageBlock(StreamInput in) throws IOException { + this.role = in.readString(); + Writeable.Reader reader = input -> { + String type = input.readString(); + if (type.equals("text")) { + return new TextBlock(input); + } else if (type.equals("image")) { + return new ImageBlock(input); + } else if (type.equals("document")) { + return new DocumentBlock(input); + } else { + throw new RuntimeException("Unexpected type: " + type); + } + }; + this.blockList = in.readList(reader); + } + + public static MessageBlock fromXContent(XContentParser parser) throws IOException { + if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + return new MessageBlock(parser.map()); + } + throw new XContentParseException(parser.getTokenLocation(), "Expected [START_OBJECT], got " + parser.currentToken()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params 
params) throws IOException { + builder.startObject(); + builder.field("role", this.role); + builder.startArray("content"); + for (AbstractBlock block : this.blockList) { + block.toXContent(builder, params); + } + builder.endArray(); + builder.endObject(); + return builder; + } + + public interface Block { + String getType(); + } + + public static abstract class AbstractBlock implements Block, Writeable, ToXContent { + + @Override + abstract public String getType(); + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnsupportedOperationException("Not implemented."); + } + } + + public static class TextBlock extends AbstractBlock { + + @Getter + String type = "text"; + + @Getter + @Setter + String text; + + public TextBlock(String text) { + Preconditions.checkNotNull(text, "text cannot be null."); + this.text = text; + } + + public TextBlock(StreamInput in) throws IOException { + this.text = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.type); + out.writeString(this.text); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + + builder.startObject(); + builder.field("type", "text"); + builder.field("text", this.text); + builder.endObject(); + return builder; + } + } + + public static class ImageBlock extends AbstractBlock { + + @Getter + String type = "image"; + + @Getter + @Setter + String format; + + @Getter + @Setter + String data; + + @Getter + @Setter + String url; + + public ImageBlock(Map imageBlock) { + this.format = (String) imageBlock.get("format"); + Object tmp = imageBlock.get("data"); + if (tmp != null) { + this.data = (String) tmp; + } else { + tmp = imageBlock.get("url"); + if (tmp == null) { + throw new 
IllegalArgumentException("data or url not found in imageBlock."); + } + this.url = (String) tmp; + } + + } + + public ImageBlock(String format, String data, String url) { + Preconditions.checkNotNull(format, "format cannot be null."); + if (data == null && url == null) { + throw new IllegalArgumentException("data and url cannot both be null."); + } + this.format = format; + this.data = data; + this.url = url; + } + + public ImageBlock(StreamInput in) throws IOException { + format = in.readString(); + data = in.readOptionalString(); + url = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.type); + out.writeString(this.format); + out.writeOptionalString(this.data); + out.writeOptionalString(this.url); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + Map imageMap = new HashMap<>(); + imageMap.put("format", this.format); + if (this.data != null) { + imageMap.put("data", this.data); + } else if (this.url != null) { + imageMap.put("url", this.url); + } + builder.field("image", imageMap); + builder.endObject(); + return builder; + } + } + + public static class DocumentBlock extends AbstractBlock { + + @Getter + String type = "document"; + + @Getter + @Setter + String format; + + @Getter + @Setter + String name; + + @Getter + @Setter + String data; + + public DocumentBlock(Map documentBlock) { + Preconditions.checkState(documentBlock.containsKey("format"), "format not found in the document block."); + Preconditions.checkState(documentBlock.containsKey("name"), "name not found in the document block."); + Preconditions.checkState(documentBlock.containsKey("data"), "data not found in the document block"); + + this.format = (String) documentBlock.get("format"); + this.name = (String) documentBlock.get("name"); + this.data = (String) documentBlock.get("data"); + } + + public DocumentBlock(String format, String 
name, String data) { + Preconditions.checkNotNull(format, "format cannot be null."); + Preconditions.checkNotNull(name, "name cannot be null."); + Preconditions.checkNotNull(data, "data cannot be null."); + + this.format = format; + this.name = name; + this.data = data; + } + + public DocumentBlock(StreamInput in) throws IOException { + format = in.readString(); + name = in.readString(); + data = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.type); + out.writeString(this.format); + out.writeString(this.name); + out.writeString(this.data); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject("document"); + builder.field("format", this.format); + builder.field("name", this.name); + builder.field("data", this.data); + builder.endObject(); + builder.endObject(); + return builder; + } + } + + @Getter + @Setter + private String role; + + @Getter + @Setter + private List blockList = new ArrayList<>(); + + public MessageBlock() {} + + public MessageBlock(Map map) { + setMessageBlock(map); + } + + public void setMessageBlock(Map message) { + Preconditions.checkNotNull(message, "message cannot be null."); + Preconditions.checkState(message.containsKey("role"), "message must have role."); + Preconditions.checkState(message.containsKey("content"), "message must have content."); + + this.role = (String) message.get("role"); + List> contents = (List) message.get("content"); + + for (Map content : contents) { + if (content.containsKey(TEXT_BLOCK)) { + this.blockList.add(new TextBlock((String) content.get(TEXT_BLOCK))); + } else if (content.containsKey(IMAGE_BLOCK)) { + Map imageBlock = (Map) content.get(IMAGE_BLOCK); + this.blockList.add(new ImageBlock(imageBlock)); + } else if (content.containsKey(DOCUMENT_BLOCK)) { + Map documentBlock = (Map) content.get(DOCUMENT_BLOCK); + this.blockList.add(new 
DocumentBlock(documentBlock)); + } + } + } + + @Override + public boolean equals(Object o) { + // TODO + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(this.role) + Objects.hashCode(this.blockList); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 3a8a21614e..9b875c6f7a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -19,15 +19,20 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.EnumSet; import java.util.List; import java.util.Locale; import org.apache.commons.text.StringEscapeUtils; import org.opensearch.core.common.Strings; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; +import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.gson.JsonArray; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; @@ -61,20 +66,27 @@ public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory, List contexts) { - return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); + public static String getChatCompletionPrompt( + Llm.ModelProvider provider, + String question, + List chatHistory, + List contexts + ) { + return getChatCompletionPrompt(provider, DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts, null); } // TODO Currently, this is OpenAI specific. 
Change this to indicate as such or address it as part of // future prompt template management work. public static String getChatCompletionPrompt( + Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List chatHistory, - List contexts + List contexts, + List llmMessages ) { - return buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); + return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, llmMessages); } enum ChatRole { @@ -134,37 +146,132 @@ public static String buildSingleStringPrompt( return bldr.toString(); } + /** + * Message APIs such as OpenAI's Chat Completion API and Anthropic's Messages API + * use an array of messages as input to the LLM and they are better suited for + * multi-modal interactions using text and images. + * + * @param provider + * @param systemPrompt + * @param userInstructions + * @param question + * @param chatHistory + * @param contexts + * @return + */ @VisibleForTesting static String buildMessageParameter( + Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List chatHistory, List contexts ) { + return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, null); + } + + static String buildMessageParameter( + Llm.ModelProvider provider, + String systemPrompt, + String userInstructions, + String question, + List chatHistory, + List contexts, + List llmMessages + ) { // TODO better prompt template management is needed here. if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) { - systemPrompt = DEFAULT_SYSTEM_PROMPT; + // Some model providers such as Anthropic do not allow the system prompt as part of the message body. 
+ userInstructions = DEFAULT_SYSTEM_PROMPT; } - JsonArray messageArray = new JsonArray(); + MessageArrayBuilder messageArrayBuilder = new MessageArrayBuilder(provider); + + // Build the system prompt (only one per conversation/session) + if (!Strings.isNullOrEmpty(systemPrompt)) { + messageArrayBuilder.startMessage(ChatRole.SYSTEM); + messageArrayBuilder.addTextContent(systemPrompt); + messageArrayBuilder.endMessage(); + } + + // Anthropic does not allow two consecutive messages of the same role + // so we combine all user messages and an array of contents. + messageArrayBuilder.startMessage(ChatRole.USER); + boolean lastRoleIsAssistant = false; + if (!Strings.isNullOrEmpty(userInstructions)) { + messageArrayBuilder.addTextContent(userInstructions); + } - messageArray.addAll(getPromptTemplateAsJsonArray(systemPrompt, userInstructions)); for (int i = 0; i < contexts.size(); i++) { - messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)).toJson()); + messageArrayBuilder.addTextContent("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); } + if (!chatHistory.isEmpty()) { // The oldest interaction first - List messages = Messages.fromInteractions(chatHistory).getMessages(); - Collections.reverse(messages); - messages.forEach(m -> messageArray.add(m.toJson())); + int idx = chatHistory.size() - 1; + Interaction firstInteraction = chatHistory.get(idx); + messageArrayBuilder.addTextContent(firstInteraction.getInput()); + messageArrayBuilder.endMessage(); + messageArrayBuilder.startMessage(ChatRole.ASSISTANT, firstInteraction.getResponse()); + messageArrayBuilder.endMessage(); + + if (chatHistory.size() > 1) { + for (int i = --idx; i >= 0; i--) { + Interaction interaction = chatHistory.get(i); + messageArrayBuilder.startMessage(ChatRole.USER, interaction.getInput()); + messageArrayBuilder.endMessage(); + messageArrayBuilder.startMessage(ChatRole.ASSISTANT, interaction.getResponse()); + messageArrayBuilder.endMessage(); + } + } + + 
lastRoleIsAssistant = true; + } + + if (llmMessages != null && !llmMessages.isEmpty()) { + // TODO MessageBlock can have assistant roles for few-shot prompting. + if (lastRoleIsAssistant) { + messageArrayBuilder.startMessage(ChatRole.USER); + } + for (MessageBlock message : llmMessages) { + List blockList = message.getBlockList(); + for (MessageBlock.Block block : blockList) { + switch (block.getType()) { + case "text": + messageArrayBuilder.addTextContent(((MessageBlock.TextBlock) block).getText()); + break; + case "image": + MessageBlock.ImageBlock ib = (MessageBlock.ImageBlock) block; + if (ib.getData() != null) { + messageArrayBuilder.addImageData(ib.getFormat(), ib.getData()); + } else if (ib.getUrl() != null) { + messageArrayBuilder.addImageUrl(ib.getFormat(), ib.getUrl()); + } + break; + case "document": + MessageBlock.DocumentBlock db = (MessageBlock.DocumentBlock) block; + messageArrayBuilder.addDocumentContent(db.getFormat(), db.getName(), db.getData()); + break; + default: + break; + } + } + } + } else { + if (lastRoleIsAssistant) { + messageArrayBuilder.startMessage(ChatRole.USER, "QUESTION: " + question + "\n"); + } else { + messageArrayBuilder.addTextContent("QUESTION: " + question + "\n"); + } + messageArrayBuilder.addTextContent("ANSWER:"); } - messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); - messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); - return messageArray.toString(); + messageArrayBuilder.endMessage(); + + return messageArrayBuilder.toJsonArray().toString(); } public static String getPromptTemplate(String systemPrompt, String userInstructions) { @@ -183,6 +290,24 @@ static JsonArray getPromptTemplateAsJsonArray(String systemPrompt, String userIn return messageArray; } + /* + static JsonArray getPromptTemplateAsJsonArray(Llm.ModelProvider provider, String systemPrompt, String userInstructions) { + + MessageArrayBuilder bldr = new MessageArrayBuilder(provider); + + if 
(!Strings.isNullOrEmpty(systemPrompt)) { + bldr.startMessage(ChatRole.SYSTEM); + bldr.addTextContent(systemPrompt); + bldr.endMessage(); + } + if (!Strings.isNullOrEmpty(userInstructions)) { + bldr.startMessage(ChatRole.USER); + bldr.addTextContent(userInstructions); + bldr.endMessage(); + } + return bldr.toJsonArray(); + }*/ + @Getter static class Messages { @@ -209,6 +334,207 @@ public static Messages fromInteractions(final List interactions) { } } + interface Content { + + // All content blocks accept text + void addText(String text); + + JsonElement toJson(); + } + + interface ImageContent extends Content { + + void addImageData(String format, String data); + + void addImageUrl(String format, String url); + } + + interface DocumentContent extends Content { + void addDocument(String format, String name, String data); + } + + interface MultimodalContent extends ImageContent, DocumentContent { + + } + + private final static String CONTENT_FIELD_TEXT = "text"; + private final static String CONTENT_FIELD_TYPE = "type"; + + static class OpenAIContent implements ImageContent { + + private JsonArray json; + + public OpenAIContent() { + this.json = new JsonArray(); + } + + @Override + public void addText(String text) { + JsonObject content = new JsonObject(); + content.add(CONTENT_FIELD_TYPE, new JsonPrimitive(CONTENT_FIELD_TEXT)); + content.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text)); + json.add(content); + } + + @Override + public void addImageData(String format, String data) { + JsonObject content = new JsonObject(); + content.add("type", new JsonPrimitive("image_url")); + JsonObject urlContent = new JsonObject(); + String imageData = String.format(Locale.ROOT, "data:image/%s;base64,%s", format, data); + urlContent.add("url", new JsonPrimitive(imageData)); + content.add("image_url", urlContent); + json.add(content); + } + + @Override + public void addImageUrl(String format, String url) { + JsonObject content = new JsonObject(); + content.add("type", new 
JsonPrimitive("image_url")); + JsonObject urlContent = new JsonObject(); + urlContent.add("url", new JsonPrimitive(url)); + content.add("image_url", urlContent); + json.add(content); + } + + @Override + public JsonElement toJson() { + return this.json; + } + } + + static class BedrockContent implements MultimodalContent { + + private JsonArray json; + + public BedrockContent() { + this.json = new JsonArray(); + } + + public BedrockContent(String type, String value) { + this.json = new JsonArray(); + if (type.equals("text")) { + addText(value); + } + } + + @Override + public void addText(String text) { + JsonObject content = new JsonObject(); + content.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text)); + json.add(content); + } + + @Override + public JsonElement toJson() { + return this.json; + } + + @Override + public void addImageData(String format, String data) { + JsonObject imageData = new JsonObject(); + imageData.add("bytes", new JsonPrimitive(data)); + JsonObject image = new JsonObject(); + image.add("format", new JsonPrimitive(format)); + image.add("source", imageData); + JsonObject content = new JsonObject(); + content.add("image", image); + json.add(content); + } + + @Override + public void addImageUrl(String format, String url) { + // Bedrock does not support image URLs. 
+ } + + @Override + public void addDocument(String format, String name, String data) { + JsonObject documentData = new JsonObject(); + documentData.add("bytes", new JsonPrimitive(data)); + JsonObject document = new JsonObject(); + document.add("format", new JsonPrimitive(format)); + document.add("name", new JsonPrimitive(name)); + document.add("source", documentData); + JsonObject content = new JsonObject(); + content.add("document", document); + json.add(content); + } + } + + static class MessageArrayBuilder { + + private final Llm.ModelProvider provider; + private List messages = new ArrayList<>(); + private Message message = null; + private Content content = null; + + public MessageArrayBuilder(Llm.ModelProvider provider) { + // OpenAI or Bedrock Converse API + if (!EnumSet.of(Llm.ModelProvider.OPENAI, Llm.ModelProvider.BEDROCK_CONVERSE).contains(provider)) { + throw new IllegalArgumentException("Unsupported provider: " + provider); + } + this.provider = provider; + } + + public void startMessage(ChatRole role) { + this.message = new Message(); + this.message.setChatRole(role); + if (this.provider == Llm.ModelProvider.OPENAI) { + content = new OpenAIContent(); + } else if (this.provider == Llm.ModelProvider.BEDROCK_CONVERSE) { + content = new BedrockContent(); + } + } + + public void startMessage(ChatRole role, String text) { + startMessage(role); + addTextContent(text); + } + + public void endMessage() { + this.message.setContent(this.content); + this.messages.add(this.message); + message = null; + content = null; + } + + public void addTextContent(String content) { + if (this.message == null || this.content == null) { + throw new RuntimeException("You must call startMessage before calling addTextContent !!"); + } + this.content.addText(content); + } + + public void addImageData(String format, String data) { + if (this.content != null && this.content instanceof ImageContent) { + ((ImageContent) this.content).addImageData(format, data); + } + } + + public void 
addImageUrl(String format, String url) { + if (this.content != null && this.content instanceof ImageContent) { + ((ImageContent) this.content).addImageUrl(format, url); + } + } + + public void addDocumentContent(String format, String name, String data) { + if (this.content != null && this.content instanceof DocumentContent) { + ((DocumentContent) this.content).addDocument(format, name, data); + } + } + + public JsonArray toJsonArray() { + Preconditions + .checkState(this.message == null && this.content == null, "You must call endMessage before calling toJsonArray !!"); + + JsonArray ja = new JsonArray(); + for (Message message : messages) { + ja.add(message.toJson()); + } + return ja; + } + } + // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle // vendor specific messages. static class Message { @@ -233,6 +559,12 @@ public Message(ChatRole chatRole, String content) { setContent(content); } + public Message(ChatRole chatRole, Content content) { + this(); + setChatRole(chatRole); + setContent(content); + } + public void setChatRole(ChatRole chatRole) { this.chatRole = chatRole; json.remove(MESSAGE_FIELD_ROLE); @@ -245,6 +577,11 @@ public void setContent(String content) { json.add(MESSAGE_FIELD_CONTENT, new JsonPrimitive(this.content)); } + public void setContent(Content content) { + json.remove(MESSAGE_FIELD_CONTENT); + json.add(MESSAGE_FIELD_CONTENT, content.toJson()); + } + public JsonObject toJson() { return json; } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 49f164cdb5..23eb6f3d3a 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ 
b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -25,6 +25,8 @@ import java.io.EOFException; import java.io.IOException; +import java.util.List; +import java.util.Map; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; @@ -33,10 +35,22 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentHelper; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { + private List messageList = null; + + public GenerativeQAParamExtBuilderTests() { + Map imageMap = Map.of("image", Map.of("format", "jpg", "url", "https://xyz.com/file.jpg")); + Map textMap = Map.of("text", "what is this"); + Map contentMap = Map.of(); + Map map = Map.of("role", "user", "content", List.of(textMap, imageMap)); + MessageBlock mb = new MessageBlock(map); + messageList = List.of(mb); + } + public void testCtor() throws IOException { GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); GenerativeQAParameters parameters = new GenerativeQAParameters( @@ -115,7 +129,7 @@ public void testParse() throws IOException { } public void testXContentRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null, messageList); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); diff --git 
a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index c36dcdb2a5..e5caa70ed7 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -24,6 +24,7 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.opensearch.action.search.SearchRequest; import org.opensearch.core.common.io.stream.StreamOutput; @@ -31,10 +32,22 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentGenerator; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParametersTests extends OpenSearchTestCase { + private List messageList = null; + + public GenerativeQAParametersTests() { + Map imageMap = Map.of("image", Map.of("format", "jpg", "url", "https://xyz.com/file.jpg")); + Map textMap = Map.of("text", "what is this"); + Map contentMap = Map.of(); + Map map = Map.of("role", "user", "content", List.of(textMap, imageMap)); + MessageBlock mb = new MessageBlock(map); + messageList = List.of(mb); + } + public void testGenerativeQAParameters() { GenerativeQAParameters params = new GenerativeQAParameters( "conversation_id", @@ -55,6 +68,29 @@ public void testGenerativeQAParameters() { assertEquals(params, actual); } + public void testGenerativeQAParametersWithLlmMessages() { + + GenerativeQAParameters params = new GenerativeQAParameters( + "conversation_id", + "llm_model", + "llm_question", + 
"system_prompt", + "user_instructions", + null, + null, + null, + null, + this.messageList + ); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + // MessageBlock messageBlock = actual.getMessageBlock(); + assertEquals(params, actual); + } + static class DummyStreamOutput extends StreamOutput { List list = new ArrayList<>(); @@ -62,6 +98,7 @@ static class DummyStreamOutput extends StreamOutput { @Override public void writeString(String str) { + System.out.println("Adding string: " + str); list.add(str); } @@ -123,12 +160,13 @@ public void testWriteTo() throws IOException { contextSize, interactionSize, timeout, - llmResponseField + llmResponseField, + messageList ); StreamOutput output = new DummyStreamOutput(); parameters.writeTo(output); List actual = ((DummyStreamOutput) output).getList(); - assertEquals(6, actual.size()); + assertEquals(12, actual.size()); assertEquals(conversationId, actual.get(0)); assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); @@ -190,7 +228,8 @@ public void testToXConent() throws IOException { null, null, null, - null + null, + messageList ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index f3a4bf8284..d70739b8cd 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ 
b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -43,6 +43,7 @@ public void testCtor() { systemPrompt, userInstructions, Llm.ModelProvider.OPENAI, + null, null ); @@ -81,6 +82,7 @@ public void testGettersSetters() { systemPrompt, userInstructions, Llm.ModelProvider.OPENAI, + null, null ); assertEquals(model, input.getModel()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 2dc06366f8..5e5f72b59a 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -93,7 +93,7 @@ public void testBuildMessageParameter() { ) ) ); - String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); + String parameter = PromptUtil.getChatCompletionPrompt(Llm.ModelProvider.BEDROCK_CONVERSE, question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); } @@ -120,6 +120,7 @@ public void testChatCompletionApi() throws Exception { "prompt", "instructions", Llm.ModelProvider.OPENAI, + null, null ); doAnswer(invocation -> { @@ -164,6 +165,56 @@ public void testChatCompletionApiForBedrock() throws Exception { "prompt", "instructions", Llm.ModelProvider.BEDROCK, + null, + null + ); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertEquals("answer", 
output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + } + + public void testMessageApiForBedrockConverse() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map messageMap = Map.of("role", "agent", "content", "answer"); + Map text = Map.of("text", "answer"); + List list = List.of(text); + Map content = Map.of("content", list); + Map message = Map.of("message", content); + Map dataAsMap = Map.of("output", message); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "bedrock-converse/model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK_CONVERSE, + null, null ); doAnswer(invocation -> { @@ -208,6 +259,7 @@ public void testChatCompletionApiForCohere() throws Exception { "prompt", "instructions", Llm.ModelProvider.COHERE, + null, null ); doAnswer(invocation -> { @@ -253,6 +305,7 @@ public void testChatCompletionApiForCohereWithError() throws Exception { "prompt", "instructions", Llm.ModelProvider.COHERE, + null, null ); doAnswer(invocation -> { @@ -300,7 +353,8 @@ public void testChatCompletionApiForFoo() throws Exception { "prompt", 
"instructions", null, - llmRespondField + llmRespondField, + null ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -347,7 +401,8 @@ public void testChatCompletionApiForFooWithError() throws Exception { "prompt", "instructions", null, - llmRespondField + llmRespondField, + null ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -395,7 +450,8 @@ public void testChatCompletionApiForFooWithErrorUnknownMessageField() throws Exc "prompt", "instructions", null, - llmRespondField + llmRespondField, + null ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -443,7 +499,8 @@ public void testChatCompletionApiForFooWithErrorUnknownErrorField() throws Excep "prompt", "instructions", null, - llmRespondField + llmRespondField, + null ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -489,6 +546,7 @@ public void testChatCompletionThrowingError() throws Exception { "prompt", "instructions", Llm.ModelProvider.OPENAI, + null, null ); @@ -536,6 +594,7 @@ public void testChatCompletionBedrockThrowingError() throws Exception { "prompt", "instructions", Llm.ModelProvider.BEDROCK, + null, null ); doAnswer(invocation -> { @@ -585,6 +644,7 @@ public void testIllegalArgument1() { "prompt", "instructions", null, + null, null ); connector.doChatCompletion(input, ActionListener.wrap(r -> {}, e -> {})); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java new file mode 100644 index 0000000000..62c6381b55 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright 
OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +public class MessageBlockTests extends OpenSearchTestCase { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + public void testStreamRoundTrip() throws Exception { + MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); + MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); + MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "doc1", "data"); + List blocks = List.of(tb, ib, ib2, db); + MessageBlock mb = new MessageBlock(); + mb.setRole("user"); + mb.setBlockList(blocks); + BytesStreamOutput bso = new BytesStreamOutput(); + mb.writeTo(bso); + MessageBlock read = new 
MessageBlock(bso.bytes().streamInput()); + assertEquals(mb, read); + } + + public void testFromXContentParseError() throws IOException { + exceptionRule.expect(XContentParseException.class); + + MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); + // MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); + MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock(Map.of("format", "png", "data", "xyz")); + MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "doc1", "data"); + List blocks = List.of(tb, ib, ib2, db); + MessageBlock mb = new MessageBlock(); + mb.setRole("user"); + mb.setBlockList(blocks); + try (XContentBuilder builder = XContentBuilder.builder(randomFrom(XContentType.values()).xContent())) { + mb.toXContent(builder, ToXContent.EMPTY_PARAMS); + try (XContentBuilder shuffled = shuffleXContent(builder); XContentParser parser = createParser(shuffled)) { + // read = TaskResult.PARSER.apply(parser, null); + MessageBlock.fromXContent(parser); + } + } finally { + // throw new IOException("Error processing [" + mb + "]", e); + } + } + + public void testInvalidImageBlock1() { + exceptionRule.expect(IllegalArgumentException.class); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock(Map.of("format", "png")); + } + + public void testInvalidImageBlock2() { + exceptionRule.expect(IllegalArgumentException.class); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", null, null); + } + + public void testInvalidDocumentBlock1() { + exceptionRule.expect(NullPointerException.class); + MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(null, null, null); + } + + public void testInvalidDocumentBlock2() { + exceptionRule.expect(IllegalStateException.class); + MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(Map.of()); + } + + public void testDocumentBlockCtor1() { + 
MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(Map.of("format", "pdf", "name", "doc", "data", "xyz")); + assertEquals(db.format, "pdf"); + assertEquals(db.name, "doc"); + assertEquals(db.data, "xyz"); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index a3aedf4e5d..0d82a18a15 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -26,12 +26,19 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; +import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class PromptUtilTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + public void testPromptUtilStaticMethods() { assertNull(PromptUtil.getQuestionRephrasingPrompt("question", Collections.emptyList())); } @@ -72,7 +79,50 @@ public void testBuildMessageParameter() { ); contexts.add("context 1"); contexts.add("context 2"); - String parameter = PromptUtil.buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); + String parameter = PromptUtil + .buildMessageParameter(Llm.ModelProvider.BEDROCK_CONVERSE, systemPrompt, userInstructions, question, chatHistory, contexts); + Map parameters = Map.of("model", "foo", 
"messages", parameter); + assertTrue(isJson(parameter)); + } + + public void testBuildMessageParameterForOpenAI() { + String systemPrompt = "You are the best."; + String userInstructions = null; + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); + contexts.add("context 1"); + contexts.add("context 2"); + String parameter = PromptUtil + .buildMessageParameter(Llm.ModelProvider.OPENAI, systemPrompt, userInstructions, question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); } @@ -117,6 +167,139 @@ public void testBuildBedrockInputParameter() { assertTrue(parameter.contains(systemPrompt)); } + public void testBuildBedrockConverseInputParameter() { + String systemPrompt = "You are the best."; + String userInstructions = null; + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + 
ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); + contexts.add("context 1"); + contexts.add("context 2"); + MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); + MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "file1", "data"); + List blocks = List.of(tb, ib, db); + MessageBlock mb = new MessageBlock(); + mb.setBlockList(blocks); + List llmMessages = List.of(mb); + String parameter = PromptUtil + .buildMessageParameter( + Llm.ModelProvider.BEDROCK_CONVERSE, + systemPrompt, + userInstructions, + question, + chatHistory, + contexts, + llmMessages + ); + assertTrue(parameter.contains(systemPrompt)); + } + + public void testBuildOpenAIInputParameter() { + String systemPrompt = "You are the best."; + String userInstructions = null; + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); + contexts.add("context 1"); + contexts.add("context 2"); + MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); + MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); + MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); + List blocks = List.of(tb, ib, ib2); + 
MessageBlock mb = new MessageBlock(); + mb.setBlockList(blocks); + List llmMessages = List.of(mb); + String parameter = PromptUtil + .buildMessageParameter(Llm.ModelProvider.OPENAI, systemPrompt, userInstructions, question, chatHistory, contexts, llmMessages); + assertTrue(parameter.contains(systemPrompt)); + } + + public void testGetPromptTemplate() { + String systemPrompt = "you are a helpful assistant."; + String userInstructions = "lay out your answer as a sequence of steps."; + String actual = PromptUtil.getPromptTemplate(systemPrompt, userInstructions); + assertTrue(actual.contains(systemPrompt)); + assertTrue(actual.contains(userInstructions)); + } + + public void testMessageCtor() { + PromptUtil.Message message = new PromptUtil.Message(PromptUtil.ChatRole.USER, new PromptUtil.OpenAIContent()); + assertEquals(message.getChatRole(), PromptUtil.ChatRole.USER); + } + + public void testBedrockContentCtor() { + PromptUtil.Content content = new PromptUtil.BedrockContent("text", "foo"); + assertTrue(content.toJson().toString().contains("foo")); + } + + public void testMessageArrayBuilderCtor1() { + exceptionRule.expect(IllegalArgumentException.class); + PromptUtil.MessageArrayBuilder builder = new PromptUtil.MessageArrayBuilder(Llm.ModelProvider.COHERE); + } + + public void testMessageArrayBuilderInvalidUsage1() { + exceptionRule.expect(RuntimeException.class); + PromptUtil.MessageArrayBuilder builder = new PromptUtil.MessageArrayBuilder(Llm.ModelProvider.OPENAI); + builder.addTextContent("boom"); + } + private boolean isJson(String Json) { try { new JSONObject(Json);