From 36106d0c382b86f06ac6ce0ed329b414ebd67513 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 09:56:15 -0800 Subject: [PATCH] enhance parsing model response function for more edge cases (#2122) (#2129) * enhance parsing model response function for more edge cases Signed-off-by: Yaliang Wu * add more unit test Signed-off-by: Yaliang Wu * fine tune code; fix some bug Signed-off-by: Yaliang Wu * add more unit test Signed-off-by: Yaliang Wu * fix tool name bug Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 311b9710413fd5ad1da3d092eb591fecfb09acd6) Co-authored-by: Yaliang Wu --- .../engine/algorithms/agent/AgentUtils.java | 180 ++++++++++- .../algorithms/agent/MLChatAgentRunner.java | 126 ++------ .../algorithms/agent/MLFlowAgentRunner.java | 7 +- .../ml/engine/memory/MLMemoryManager.java | 1 - .../algorithms/agent/AgentUtilsTest.java | 279 +++++++++++++++++- 5 files changed, 473 insertions(+), 120 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index b9a47b63de..301304b556 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -5,12 +5,20 @@ package org.opensearch.ml.engine.algorithms.agent; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; @@ -19,10 +27,13 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -33,7 +44,11 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.utils.StringUtils; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class AgentUtils { public static final String SELECTED_TOOLS = "selected_tools"; @@ -167,23 +182,166 @@ public static String extractModelResponseJson(String text) { return extractModelResponseJson(text, null); } - public static String extractModelResponseJson(String text, List llmResponsePatterns) { - Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```"); - Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text); - - if (jsonBlockMatcher.find()) { - return jsonBlockMatcher.group(1); + public static Map parseLLMOutput( + ModelTensorOutput tmpModelTensorOutput, + List llmResponsePatterns, + Set inputTools + ) { + Map modelOutput = new HashMap<>(); + Map dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { + String llmReasoningResponse = (String) dataAsMap.get("response"); + String thoughtResponse = null; + try { + thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns); + modelOutput.put(THOUGHT_RESPONSE, thoughtResponse); + } catch (IllegalArgumentException e) { + modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse); + thoughtResponse = llmReasoningResponse; + } + parseThoughtResponse(modelOutput, thoughtResponse); } else { - String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS); - if (matchedPart == null && llmResponsePatterns != null) { - // If no match is found, try additional patterns if provided - matchedPart = findMatchedPart(text, llmResponsePatterns); + extractParams(modelOutput, dataAsMap, THOUGHT); + extractParams(modelOutput, dataAsMap, ACTION); + extractParams(modelOutput, dataAsMap, ACTION_INPUT); + extractParams(modelOutput, dataAsMap, FINAL_ANSWER); + try { + modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap)); + } catch (Exception e) { + log.warn("Failed to parse model response", e); + } + } + String action = modelOutput.get(ACTION); + if (action != null) { + String matchedTool = getMatchedTool(inputTools, action); + if (matchedTool != null) { + modelOutput.put(ACTION, matchedTool); + } else { + modelOutput.remove(ACTION); + } + } + if (!modelOutput.containsKey(ACTION) && !modelOutput.containsKey(FINAL_ANSWER)) { + modelOutput.put(FINAL_ANSWER, modelOutput.get(THOUGHT_RESPONSE)); + } + return modelOutput; + } + + public static String getMatchedTool(Collection tools, String action) { + for (String tool : tools) { + if (action.toLowerCase(Locale.ROOT).contains(tool.toLowerCase(Locale.ROOT))) { + return tool; } + } + return null; + } + + public static void extractParams(Map modelOutput, Map dataAsMap, String paramName) { + if (dataAsMap.containsKey(paramName)) { + modelOutput.put(paramName, toJson(dataAsMap.get(paramName))); + } + } + + public static String extractModelResponseJson(String text, List llmResponsePatterns) { + if (text.contains("```json")) { + text = text.substring(text.indexOf("```json") + "```json".length()); + if (text.contains("```")) { + text = text.substring(0, text.lastIndexOf("```")); + } + } + text = text.trim(); + if (isJson(text)) { + return text; + } + String matchedPart = null; + if (llmResponsePatterns != null) { + matchedPart = findMatchedPart(text, llmResponsePatterns); if (matchedPart != null) { return matchedPart; } - throw new IllegalArgumentException("Model output is invalid"); } + matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS); + if (matchedPart != null) { + return matchedPart; + } + throw new IllegalArgumentException("Model output is invalid"); + } + + public static void parseThoughtResponse(Map modelOutput, String thoughtResponse) { + if (thoughtResponse != null) { + if (isJson(thoughtResponse)) { + modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class))); + } else {// sometimes LLM return invalid json response + String thought = extractThought(thoughtResponse); + String action = extractAction(thoughtResponse); + String actionInput = extractActionInput(thoughtResponse); + String finalAnswer = extractFinalAnswer(thoughtResponse); + if (thought != null) { + modelOutput.put(THOUGHT, thought); + } + if (action != null) { + modelOutput.put(ACTION, action); + } + if (actionInput != null) { + modelOutput.put(ACTION_INPUT, actionInput); + } + if (finalAnswer != null) { + modelOutput.put(FINAL_ANSWER, finalAnswer); + } + } + } + } + + public static String extractFinalAnswer(String text) { + String result = null; + if (text.contains("\"final_answer\"")) { + String pattern = "\"final_answer\"\\s*:\\s*\"(.*?)$"; + Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); + Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text); + if (jsonBlockMatcher.find()) { + result = jsonBlockMatcher.group(1); + } + } + return result; + } + + public static String extractThought(String text) { + String result = null; + if (text.contains("\"thought\"")) { + String pattern = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*[\"final_answer\"|\"action\"]"; + Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); + Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text); + if (jsonBlockMatcher.find()) { + result = jsonBlockMatcher.group(1); + } + } + return result; + } + + public static String extractAction(String text) { + String result = null; + if (text.contains("\"action\"")) { + String pattern = "\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)"; + Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); + Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text); + if (jsonBlockMatcher.find()) { + result = jsonBlockMatcher.group(1); + } + } + return result; + } + + public static String extractActionInput(String text) { + String result = null; + if (text.contains("\"action_input\"")) { + String pattern = "\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\""; + Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); // Add Pattern.DOTALL to match across newlines + Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text); + if (jsonBlockMatcher.find()) { + result = jsonBlockMatcher.group(1); + result = result.replace("\\\"", "\""); + } + } + return result; } public static String findMatchedPart(String text, List patternList) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index efbf8c2f88..bd64c36828 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,10 +7,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.common.utils.StringUtils.isJson; -import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; @@ -19,11 +16,12 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; -import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; import java.security.PrivilegedActionException; @@ -62,7 +60,6 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; @@ -189,23 +186,15 @@ private void runReAct( String sessionId, ActionListener listener ) { - final List inputTools = getToolNames(tools); - String question = parameters.get(MLAgentExecutor.QUESTION); - String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); - boolean verbose = parameters.containsKey(VERBOSE) && Boolean.parseBoolean(parameters.get(VERBOSE)); - boolean traceDisabled = parameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(parameters.get(DISABLE_TRACE)); - Map tmpParameters = constructLLMParams(llm, parameters); - String prompt = constructLLMPrompt(tools, parameters, inputTools, tmpParameters); + String prompt = constructLLMPrompt(tools, tmpParameters); tmpParameters.put(PROMPT, prompt); + final String finalPrompt = prompt; - List traceTensors = createModelTensors(sessionId, parentInteractionId); - - StringBuilder scratchpadBuilder = new StringBuilder(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); - AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); - tmpParameters.put(PROMPT, newPrompt.get()); - String finalPrompt = prompt; + String question = tmpParameters.get(MLAgentExecutor.QUESTION); + String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); + boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false")); + boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); // Create root interaction. ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; @@ -224,6 +213,12 @@ private void runReAct( lastLlmListener.set(firstListener); StepListener lastStepListener = firstListener; + StringBuilder scratchpadBuilder = new StringBuilder(); + StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); + AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + + List traceTensors = createModelTensors(sessionId, parentInteractionId); int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2; for (int i = 0; i < maxIterations; i++) { int finalI = i; @@ -234,8 +229,8 @@ private void runReAct( if (finalI % 2 == 0) { MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); - List llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class); - Map modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); + Map modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns, tools.keySet()); String thought = String.valueOf(modelOutput.get(THOUGHT)); String action = String.valueOf(modelOutput.get(ACTION)); @@ -287,9 +282,7 @@ private void runReAct( "LLM" ); - action = getMatchingTool(tools, action); - - if (tools.containsKey(action) && inputTools.contains(action)) { + if (tools.containsKey(action)) { Map toolParams = constructToolParams( tools, toolSpecMap, @@ -402,42 +395,6 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private static Map parseLLMOutput(ModelTensorOutput tmpModelTensorOutput, List llmResponsePatterns) { - Map modelOutput = new HashMap<>(); - Map dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); - if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { - String llmReasoningResponse = (String) dataAsMap.get("response"); - String thoughtResponse = null; - try { - thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns); - modelOutput.put(THOUGHT_RESPONSE, thoughtResponse); - } catch (IllegalArgumentException e) { - modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse); - modelOutput.put(FINAL_ANSWER, llmReasoningResponse); - } - if (isJson(thoughtResponse)) { - modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class))); - } - } else { - extractParams(modelOutput, dataAsMap, THOUGHT); - extractParams(modelOutput, dataAsMap, ACTION); - extractParams(modelOutput, dataAsMap, ACTION_INPUT); - extractParams(modelOutput, dataAsMap, FINAL_ANSWER); - try { - modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap)); - } catch (Exception e) { - log.warn("Failed to parse model response", e); - } - } - return modelOutput; - } - - private static void extractParams(Map modelOutput, Map dataAsMap, String paramName) { - if (dataAsMap.containsKey(paramName)) { - modelOutput.put(paramName, toJson(dataAsMap.get(paramName))); - } - } - private static List createFinalAnswerTensors(List sessionId, List lastThought) { List finalModelTensors = sessionId; finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); @@ -480,7 +437,7 @@ private static void addToolOutputToAddtionalInfo( MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { String outputString = outputToOutputString(output); - String toolOutputKey = String.format("%s.output", toolSpec.getType()); + String toolOutputKey = String.format("%s.output", getToolName(toolSpec)); if (additionalInfo.get(toolOutputKey) != null) { List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); @@ -512,7 +469,6 @@ private static void runTool( Map llmToolTmpParameters = new HashMap<>(); llmToolTmpParameters.putAll(tmpParameters); llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters()); - // TODO: support tool parameter override : langauge_model_tool.prompt llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { @@ -550,16 +506,6 @@ private static Map constructToolParams( return toolParams; } - private static String getMatchingTool(Map tools, String name) { - String toolName = name; - for (String key : tools.keySet()) { - if (name.toLowerCase().contains(key.toLowerCase())) { - toolName = key; - } - } - return toolName; - } - private static void saveTraceData( ConversationIndexMemory conversationIndexMemory, String memory, @@ -655,21 +601,16 @@ private static List createModelTensors(String sessionId, String pa return cotModelTensors; } - private static String constructLLMPrompt( - Map tools, - Map parameters, - List inputTools, - Map tmpParameters - ) { - String prompt = parameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE); + private static String constructLLMPrompt(Map tools, Map tmpParameters) { + String prompt = tmpParameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE); StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); prompt = promptSubstitutor.replace(prompt); - prompt = AgentUtils.addPrefixSuffixToPrompt(parameters, prompt); - prompt = AgentUtils.addToolsToPrompt(tools, parameters, inputTools, prompt); - prompt = AgentUtils.addIndicesToPrompt(parameters, prompt); - prompt = AgentUtils.addExamplesToPrompt(parameters, prompt); - prompt = AgentUtils.addChatHistoryToPrompt(parameters, prompt); - prompt = AgentUtils.addContextToPrompt(parameters, prompt); + prompt = AgentUtils.addPrefixSuffixToPrompt(tmpParameters, prompt); + prompt = AgentUtils.addToolsToPrompt(tools, tmpParameters, getToolNames(tools), prompt); + prompt = AgentUtils.addIndicesToPrompt(tmpParameters, prompt); + prompt = AgentUtils.addExamplesToPrompt(tmpParameters, prompt); + prompt = AgentUtils.addChatHistoryToPrompt(tmpParameters, prompt); + prompt = AgentUtils.addContextToPrompt(tmpParameters, prompt); return prompt; } @@ -699,17 +640,10 @@ private static Map constructLLMParams(LLMSpec llm, Map params, ActionListener nextStepListener = new StepListener<>(); int finalI = i; previousStepListener.whenComplete(output -> { - String key = previousToolSpec.getName(); - String outputKey = previousToolSpec.getName() != null - ? previousToolSpec.getName() + ".output" - : previousToolSpec.getType() + ".output"; + String key = getToolName(previousToolSpec); + String outputKey = key + ".output"; String outputResponse = parseResponse(output); params.put(outputKey, escapeJson(outputResponse)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index 1dc76711ec..ec7a805c9e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -139,7 +139,6 @@ public void createInteraction( * @param actionListener get all the final interactions that are not traces */ public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { - Preconditions.checkArgument(lastNInteraction > 0, "History message size must be at least 1."); log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 5be75ad935..0a0af3f60c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -10,21 +10,30 @@ import static org.mockito.Mockito.when; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; public class AgentUtilsTest { @@ -32,9 +41,105 @@ public class AgentUtilsTest { @Mock private Tool tool1, tool2; + private Map> llmResponseExpectedParseResults; + + private String responseForAction = "---------------------\n{\n " + + "\"thought\": \"Let me search our index to find population projections\", \n " + + "\"action\": \"VectorDBTool\",\n " + + "\"action_input\": \"Seattle population projection 2023\"\n}"; + + private String responseForActionWrongAction = "---------------------\n{\n " + + "\"thought\": \"Let me search our index to find population projections\", \n " + + "\"action\": \"Let me run VectorDBTool to get more data\",\n " + + "\"action_input\": \"Seattle population projection 2023\"\n}"; + + private String responseForActionNullAction = "---------------------\n{\n " + + "\"thought\": \"Let me search our index to find population projections\" \n }"; + + private String responseNotFollowJsonFormat = "Final answer is I don't know"; + private String responseForActionInvalidJson = "---------------------\n{\n " + + "\"thought\": \"Let me search our index to find population projections\", \n " + + "\"action\": \"VectorDBTool\",\n " + + "\"action_input\": \"Seattle population projection 2023\""; + private String responseForFinalAnswer = "---------------------```json\n{\n " + + "\"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n " + + "\"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```"; + private String responseForFinalAnswerInvalidJson = + "\"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n " + + "\"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```"; + + private String responseForFinalAnswerWithJson = "---------------------```json\n{\n " + + "\"thought\": \"Now I know the final answer\",\n " + + "\"final_answer\": \"PPLTool generates such query ```json source=iris_data | fields petal_length_in_cm,petal_width_in_cm | kmeans centroids=3 ```.\"\n}\n```"; + + private String wrongResponseForAction = "---------------------```json\n{\n " + + "\"thought\": \"Let's try VectorDBTool\",\n " + + "\"action\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```"; + @Before public void setup() { MockitoAnnotations.openMocks(this); + llmResponseExpectedParseResults = new HashMap<>(); + Map responseForActionExpectedResult = Map + .of( + THOUGHT, + "Let me search our index to find population projections", + ACTION, + "VectorDBTool", + ACTION_INPUT, + "Seattle population projection 2023" + ); + llmResponseExpectedParseResults.put(responseForAction, responseForActionExpectedResult); + llmResponseExpectedParseResults.put(responseForActionWrongAction, responseForActionExpectedResult); + llmResponseExpectedParseResults.put(responseForActionInvalidJson, responseForActionExpectedResult); + Map responseForActionNullActionExpectedResult = Map + .of( + THOUGHT, + "Let me search our index to find population projections", + FINAL_ANSWER, + "{\n \"thought\": \"Let me search our index to find population projections\" \n }" + ); + llmResponseExpectedParseResults.put(responseForActionNullAction, responseForActionNullActionExpectedResult); + + Map responseNotFollowJsonFormatExpectedResult = Map.of(FINAL_ANSWER, responseNotFollowJsonFormat); + llmResponseExpectedParseResults.put(responseNotFollowJsonFormat, responseNotFollowJsonFormatExpectedResult); + + Map responseForFinalAnswerExpectedResult = Map + .of( + THOUGHT, + "Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:", + FINAL_ANSWER, + "After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius." + ); + llmResponseExpectedParseResults.put(responseForFinalAnswer, responseForFinalAnswerExpectedResult); + Map responseForFinalAnswerExpectedResultExpectedResult = Map + .of( + THOUGHT, + "Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:", + FINAL_ANSWER, + "After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```" + ); + llmResponseExpectedParseResults.put(responseForFinalAnswerInvalidJson, responseForFinalAnswerExpectedResultExpectedResult); + Map responseForFinalAnswerWithJsonExpectedResultExpectedResult = Map + .of( + THOUGHT, + "Now I know the final answer", + FINAL_ANSWER, + "PPLTool generates such query ```json source=iris_data | fields petal_length_in_cm,petal_width_in_cm | kmeans centroids=3 ```." + ); + llmResponseExpectedParseResults.put(responseForFinalAnswerWithJson, responseForFinalAnswerWithJsonExpectedResultExpectedResult); + + Map wrongResponseForActionExpectedResultExpectedResult = Map + .of( + THOUGHT, + "Let's try VectorDBTool", + FINAL_ANSWER, + "{\n" + + " \"thought\": \"Let's try VectorDBTool\",\n" + + " \"action\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n" + + "}" + ); + llmResponseExpectedParseResults.put(wrongResponseForAction, wrongResponseForActionExpectedResultExpectedResult); } @@ -262,7 +367,7 @@ public void testExtractModelResponseJsonWithInvalidModelOutput() { @Test public void testExtractModelResponseJsonWithValidModelOutput() { String text = - "This is the model response\n```json\n{\"thought\":\"use CatIndexTool to get index first\",\"action\":\"CatIndexTool\"}```"; + "This is the model response\n```json\n{\"thought\":\"use CatIndexTool to get index first\",\"action\":\"CatIndexTool\"} \n``` other content"; String responseJson = AgentUtils.extractModelResponseJson(text); assertEquals("{\"thought\":\"use CatIndexTool to get index first\",\"action\":\"CatIndexTool\"}", responseJson); } @@ -276,34 +381,192 @@ public void testExtractModelResponseJson_ThoughtFinalAnswer() { + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n" + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n" + "}"; - System.out.println(result); Assert.assertEquals(expectedResult, result); } @Test public void testExtractModelResponseJson_ThoughtFinalAnswerJsonBlock() { - String text = - "---------------------```json\n{\n \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```"; + String text = responseForFinalAnswer; String result = AgentUtils.extractModelResponseJson(text); String expectedResult = "{\n" + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n" + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n" + "}"; - System.out.println(result); Assert.assertEquals(expectedResult, result); } @Test public void testExtractModelResponseJson_ThoughtActionInput() { - String text = - "---------------------\n{\n \"thought\": \"Let me search our index to find population projections\", \n \"action\": \"VectorDBTool\",\n \"action_input\": \"Seattle population projection 2023\"\n}"; + String text = responseForAction; String result = AgentUtils.extractModelResponseJson(text); String expectedResult = "{\n" + " \"thought\": \"Let me search our index to find population projections\", \n" + " \"action\": \"VectorDBTool\",\n" + " \"action_input\": \"Seattle population projection 2023\"\n" + "}"; - System.out.println(result); Assert.assertEquals(expectedResult, result); } + + @Test + public void testExtractMethods() { + List textList = List.of(responseForAction, responseForActionInvalidJson); + for (String text : textList) { + String thought = AgentUtils.extractThought(text); + String action = AgentUtils.extractAction(text); + String actionInput = AgentUtils.extractActionInput(text); + String finalAnswer = AgentUtils.extractFinalAnswer(text); + Assert.assertEquals("Let me search our index to find population projections", thought); + Assert.assertEquals("VectorDBTool\",\n ", action); + Assert.assertEquals("Seattle population projection 2023", actionInput); + Assert.assertNull(finalAnswer); + } + } + + @Test + public void testExtractMethods_FinalAnswer() { + List textList = List.of(responseForFinalAnswer, responseForFinalAnswerInvalidJson); + for (String text : textList) { + String thought = AgentUtils.extractThought(text); + String action = AgentUtils.extractAction(text); + String actionInput = AgentUtils.extractActionInput(text); + String finalAnswer = AgentUtils.extractFinalAnswer(text); + Assert + .assertEquals( + "Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:", + thought + ); + Assert.assertNull(action); + Assert.assertNull(actionInput); + Assert + .assertEquals( + "After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```", + finalAnswer + ); + } + } + + @Test + public void testParseLLMOutput() { + Set tools = Set.of("VectorDBTool", "CatIndexTool"); + for (Map.Entry> entry : llmResponseExpectedParseResults.entrySet()) { + ModelTensorOutput modelTensoOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors( + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", entry.getKey())).build()) + ) + .build() + ) + ) + .build(); + Map output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools); + for (String key : entry.getValue().keySet()) { + Assert.assertEquals(entry.getValue().get(key), output.get(key)); + } + } + } + + @Test + public void testParseLLMOutput_MultipleFields() { + Set tools = Set.of("VectorDBTool", "CatIndexTool"); + String thought = "Let me run VectorDBTool to get more information"; + String toolName = "vectordbtool"; + ModelTensorOutput modelTensoOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name("response").dataAsMap(Map.of(THOUGHT, thought, ACTION, toolName)).build() + ) + ) + .build() + ) + ) + .build(); + Map output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools); + Assert.assertEquals(3, output.size()); + Assert.assertEquals(thought, output.get(THOUGHT)); + Assert.assertEquals("VectorDBTool", output.get(ACTION)); + Set expected = Set + .of( + "{\"action\":\"vectordbtool\",\"thought\":\"Let me run VectorDBTool to get more information\"}", + "{\"thought\":\"Let me run VectorDBTool to get more information\",\"action\":\"vectordbtool\"}" + ); + Assert.assertTrue(expected.contains(output.get(THOUGHT_RESPONSE))); + } + + @Test + public void testParseLLMOutput_MultipleFields_NoActionAndFinalAnswer() { + Set tools = Set.of("VectorDBTool", "CatIndexTool"); + String key1 = "dummy key1"; + String value1 = "dummy value1"; + String key2 = "dummy key2"; + String value2 = "dummy value2"; + ModelTensorOutput modelTensoOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors( + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of(key1, value1, key2, value2)).build()) + ) + .build() + ) + ) + .build(); + Map output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools); + Assert.assertEquals(2, output.size()); + Assert.assertFalse(output.containsKey(THOUGHT)); + Assert.assertFalse(output.containsKey(ACTION)); + Set expected = Set + .of( + "{\"dummy key1\":\"dummy value1\",\"dummy key2\":\"dummy value2\"}", + "{\"dummy key2\":\"dummy value2\",\"dummy key1\":\"dummy value1\"}" + ); + Assert.assertTrue(expected.contains(output.get(THOUGHT_RESPONSE))); + Assert.assertEquals(output.get(THOUGHT_RESPONSE), output.get(FINAL_ANSWER)); + } + + @Test + public void testParseLLMOutput_OneFields_NoActionAndFinalAnswer() { + Set tools = Set.of("VectorDBTool", "CatIndexTool"); + String thought = "Let me run VectorDBTool to get more information"; + ModelTensorOutput modelTensoOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(Map.of(THOUGHT, thought)).build())) + .build() + ) + ) + .build(); + Map output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools); + Assert.assertEquals(3, output.size()); + Assert.assertEquals(thought, output.get(THOUGHT)); + Assert.assertFalse(output.containsKey(ACTION)); + Assert.assertEquals("{\"thought\":\"Let me run VectorDBTool to get more information\"}", output.get(THOUGHT_RESPONSE)); + Assert.assertEquals("{\"thought\":\"Let me run VectorDBTool to get more information\"}", output.get(FINAL_ANSWER)); + } + + @Test + public void testExtractThought_InvalidResult() { + String text = responseForActionInvalidJson; + String result = AgentUtils.extractThought(text); + Assert.assertEquals("Let me search our index to find population projections", result); + } + }