Skip to content

Commit

Permalink
enhance parsing model response function for more edge cases (#2122) (#…
Browse files Browse the repository at this point in the history
…2129)

* enhance parsing model response function for more edge cases

Signed-off-by: Yaliang Wu <[email protected]>

* add more unit test

Signed-off-by: Yaliang Wu <[email protected]>

* fine tune code; fix some bug

Signed-off-by: Yaliang Wu <[email protected]>

* add more unit test

Signed-off-by: Yaliang Wu <[email protected]>

* fix tool name bug

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 311b971)

Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Feb 19, 2024
1 parent 70f1b8d commit 36106d0
Show file tree
Hide file tree
Showing 5 changed files with 473 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
Expand All @@ -19,10 +27,13 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -33,7 +44,11 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class AgentUtils {

public static final String SELECTED_TOOLS = "selected_tools";
Expand Down Expand Up @@ -167,23 +182,166 @@ public static String extractModelResponseJson(String text) {
return extractModelResponseJson(text, null);
}

public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);

if (jsonBlockMatcher.find()) {
return jsonBlockMatcher.group(1);
/**
 * Turns the first tensor of an LLM result into a flat map of agent fields.
 *
 * <p>Two shapes of tensor data are handled:
 * <ul>
 *   <li>a map with the single key {@code "response"}: the value is cast to a String, passed through
 *       {@link #extractModelResponseJson(String, List)} (falling back to the raw text when that throws
 *       {@link IllegalArgumentException}), stored under {@code THOUGHT_RESPONSE} and then handed to
 *       {@link #parseThoughtResponse(Map, String)};</li>
 *   <li>any other map: {@code THOUGHT}, {@code ACTION}, {@code ACTION_INPUT} and {@code FINAL_ANSWER}
 *       are copied via {@link #extractParams(Map, Map, String)} and the whole map is serialized into
 *       {@code THOUGHT_RESPONSE}.</li>
 * </ul>
 * Afterwards the action is replaced by the matching entry of {@code inputTools} (or dropped when none
 * matches), and if neither an action nor a final answer remains, the thought response is used as the
 * final answer.
 *
 * @param tmpModelTensorOutput model output; only output 0 / tensor 0 is read
 * @param llmResponsePatterns patterns forwarded to {@code extractModelResponseJson}
 * @param inputTools tool names the parsed action is matched against
 * @return the populated map; which keys are present depends on the response
 */
public static Map<String, String> parseLLMOutput(
ModelTensorOutput tmpModelTensorOutput,
List<String> llmResponsePatterns,
Set<String> inputTools
) {
Map<String, String> modelOutput = new HashMap<>();
// Only the first tensor of the first model output is inspected.
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
// Free-text response: try to isolate the JSON part first.
String llmReasoningResponse = (String) dataAsMap.get("response");
String thoughtResponse = null;
try {
thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns);
modelOutput.put(THOUGHT_RESPONSE, thoughtResponse);
} catch (IllegalArgumentException e) {
// No JSON / pattern match: keep the raw response so it can still be parsed leniently below.
modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse);
thoughtResponse = llmReasoningResponse;
}
parseThoughtResponse(modelOutput, thoughtResponse);
} else {
// NOTE(review): the next four lines use `text`, which is neither a parameter nor a local of this
// method, and the `if` they open appears to leave the method one closing brace short. They look like
// leftovers from an earlier extractModelResponseJson body - confirm and remove.
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
if (matchedPart == null && llmResponsePatterns != null) {
// If no match is found, try additional patterns if provided
matchedPart = findMatchedPart(text, llmResponsePatterns);
// Structured response: copy the known fields straight from the tensor data.
extractParams(modelOutput, dataAsMap, THOUGHT);
extractParams(modelOutput, dataAsMap, ACTION);
extractParams(modelOutput, dataAsMap, ACTION_INPUT);
extractParams(modelOutput, dataAsMap, FINAL_ANSWER);
try {
modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap));
} catch (Exception e) {
// Best effort: a serialization failure only means THOUGHT_RESPONSE stays unset.
log.warn("Failed to parse model response", e);
}
}
// Normalize the action to one of the configured tool names; an unknown action is discarded.
String action = modelOutput.get(ACTION);
if (action != null) {
String matchedTool = getMatchedTool(inputTools, action);
if (matchedTool != null) {
modelOutput.put(ACTION, matchedTool);
} else {
modelOutput.remove(ACTION);
}
}
// With neither a usable action nor a final answer, the thought response becomes the final answer
// (this stores null when THOUGHT_RESPONSE was never set).
if (!modelOutput.containsKey(ACTION) && !modelOutput.containsKey(FINAL_ANSWER)) {
modelOutput.put(FINAL_ANSWER, modelOutput.get(THOUGHT_RESPONSE));
}
return modelOutput;
}

/**
 * Resolves which configured tool an LLM-produced action refers to.
 *
 * <p>Matching is case-insensitive. An action that equals a tool name wins outright; otherwise the
 * longest tool name contained in the action is chosen, so that a name which is a substring of
 * another (for example {@code Search} and {@code SearchIndex}) cannot shadow the more specific tool
 * regardless of the collection's iteration order. Substring matching is kept because the action text
 * may carry surrounding noise such as quotes or punctuation.
 *
 * @param tools candidate tool names; {@code null} or empty entries are ignored
 * @param action action text produced by the model
 * @return the matching tool name exactly as it appears in {@code tools}, or {@code null} when
 *         nothing matches or either argument is {@code null}
 */
public static String getMatchedTool(Collection<String> tools, String action) {
    if (tools == null || action == null) {
        return null;
    }
    // Lower-case the action once instead of once per tool.
    String normalizedAction = action.toLowerCase(Locale.ROOT);
    String bestMatch = null;
    for (String tool : tools) {
        // An empty name is contained in every string and would match everything.
        if (tool == null || tool.isEmpty()) {
            continue;
        }
        String normalizedTool = tool.toLowerCase(Locale.ROOT);
        if (normalizedAction.equals(normalizedTool)) {
            return tool;
        }
        boolean contained = normalizedAction.contains(normalizedTool);
        if (contained && (bestMatch == null || tool.length() > bestMatch.length())) {
            bestMatch = tool;
        }
    }
    return bestMatch;
}

/**
 * Copies one entry from the raw tensor data into the model output, serialized with {@code toJson}.
 * Nothing is written when {@code dataAsMap} has no entry for {@code paramName}.
 *
 * @param modelOutput destination map
 * @param dataAsMap source data
 * @param paramName key to copy; the same key is used in the destination
 */
public static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {
    if (!dataAsMap.containsKey(paramName)) {
        return;
    }
    Object rawValue = dataAsMap.get(paramName);
    modelOutput.put(paramName, toJson(rawValue));
}

/**
 * Isolates the JSON part of a model response.
 *
 * <p>Text after the first {@code ```json} marker is kept and cut at the last remaining {@code ```}
 * (if any); the result is trimmed. If that is valid JSON it is returned as is. Otherwise the text is
 * searched with {@code llmResponsePatterns} when they are provided, or with
 * {@code MODEL_RESPONSE_PATTERNS} when they are {@code null}.
 *
 * @param text raw model response
 * @param llmResponsePatterns custom patterns; when non-null they replace the default patterns
 * @return the JSON text or the first matched part
 * @throws IllegalArgumentException if the text is not JSON and no pattern matches
 */
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
    final String fenceOpen = "```json";
    final String fenceClose = "```";

    String candidate = text;
    int openAt = candidate.indexOf(fenceOpen);
    if (openAt >= 0) {
        candidate = candidate.substring(openAt + fenceOpen.length());
        int closeAt = candidate.lastIndexOf(fenceClose);
        if (closeAt >= 0) {
            candidate = candidate.substring(0, closeAt);
        }
    }
    candidate = candidate.trim();
    if (isJson(candidate)) {
        return candidate;
    }

    // Custom patterns, when supplied, are used instead of the defaults rather than in addition to them.
    List<String> patterns = llmResponsePatterns != null ? llmResponsePatterns : MODEL_RESPONSE_PATTERNS;
    String matchedPart = findMatchedPart(candidate, patterns);
    if (matchedPart == null) {
        throw new IllegalArgumentException("Model output is invalid");
    }
    return matchedPart;
}

/**
 * Adds the fields of a thought response to {@code modelOutput}.
 *
 * <p>A valid JSON response is parsed and all of its entries are merged in. Anything else is scanned
 * with the regex-based extractors, and only the fields they find are written. A {@code null}
 * response leaves the map untouched.
 *
 * @param modelOutput map to populate
 * @param thoughtResponse response text, possibly {@code null} or malformed JSON
 */
public static void parseThoughtResponse(Map<String, String> modelOutput, String thoughtResponse) {
    if (thoughtResponse == null) {
        return;
    }
    if (isJson(thoughtResponse)) {
        modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class)));
        return;
    }

    // sometimes LLM return invalid json response
    String parsedThought = extractThought(thoughtResponse);
    String parsedAction = extractAction(thoughtResponse);
    String parsedActionInput = extractActionInput(thoughtResponse);
    String parsedFinalAnswer = extractFinalAnswer(thoughtResponse);

    if (parsedThought != null) {
        modelOutput.put(THOUGHT, parsedThought);
    }
    if (parsedAction != null) {
        modelOutput.put(ACTION, parsedAction);
    }
    if (parsedActionInput != null) {
        modelOutput.put(ACTION_INPUT, parsedActionInput);
    }
    if (parsedFinalAnswer != null) {
        modelOutput.put(FINAL_ANSWER, parsedFinalAnswer);
    }
}

/**
 * Pulls the {@code final_answer} value out of a response that is not valid JSON.
 *
 * <p>The capture starts right after the opening quote of the value and runs to the end of the
 * text, so any trailing characters (closing quote, brace, ...) are part of the result.
 *
 * @param text response text
 * @return the captured text, or {@code null} when no {@code "final_answer"} key is found
 */
public static String extractFinalAnswer(String text) {
    if (!text.contains("\"final_answer\"")) {
        return null;
    }
    // DOTALL lets the answer span several lines.
    Matcher answerMatcher = Pattern.compile("\"final_answer\"\\s*:\\s*\"(.*?)$", Pattern.DOTALL).matcher(text);
    return answerMatcher.find() ? answerMatcher.group(1) : null;
}

/**
 * Pulls the {@code thought} value out of a response that is not valid JSON.
 *
 * <p>The thought is taken to end at the quote that precedes the next key. The previous pattern
 * ended with {@code ["final_answer"|"action"]}, which is a character class matching a single
 * character rather than one of the two keys, so a thought containing {@code ", a...} was cut short.
 * The boundary is now a real alternation over the known keys. If none of them follows, the
 * boundary falls back to the next quoted key of any name, so such responses still yield a thought.
 *
 * @param text response text
 * @return the thought text, or {@code null} when there is no {@code "thought"} key or no key
 *         follows it
 */
public static String extractThought(String text) {
    if (!text.contains("\"thought\"")) {
        return null;
    }
    // Preferred boundary: the thought is followed by one of the keys the agent prompt defines.
    String knownKeyBoundary = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*\"(?:final_answer|action|action_input)\"";
    Matcher knownKeyMatcher = Pattern.compile(knownKeyBoundary, Pattern.DOTALL).matcher(text);
    if (knownKeyMatcher.find()) {
        return knownKeyMatcher.group(1);
    }
    // Fallback boundary: any quoted key after the thought.
    String anyKeyBoundary = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*\"";
    Matcher anyKeyMatcher = Pattern.compile(anyKeyBoundary, Pattern.DOTALL).matcher(text);
    return anyKeyMatcher.find() ? anyKeyMatcher.group(1) : null;
}

/**
 * Pulls the {@code action} value out of a response that is not valid JSON.
 *
 * <p>The capture starts after the opening quote of the value and stops right before an
 * {@code "action_input"} key, or at the end of the text when there is none. The closing quote and
 * any separator that follow the action name are therefore part of the result.
 *
 * @param text response text
 * @return the captured text, or {@code null} when no {@code "action"} key is found
 */
public static String extractAction(String text) {
    if (!text.contains("\"action\"")) {
        return null;
    }
    Matcher actionMatcher = Pattern.compile("\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)", Pattern.DOTALL).matcher(text);
    return actionMatcher.find() ? actionMatcher.group(1) : null;
}

/**
 * Pulls the {@code action_input} value out of a response that is not valid JSON.
 *
 * <p>The two alternatives inside the capture group together accept every character, and the group
 * is greedy, so the capture runs from the opening quote of the value to the last double quote in
 * the text. Backslash-escaped quotes in the captured text are then unescaped.
 *
 * @param text response text
 * @return the captured and unescaped text, or {@code null} when no {@code "action_input"} key with
 *         a quoted value is found
 */
public static String extractActionInput(String text) {
    if (!text.contains("\"action_input\"")) {
        return null;
    }
    // DOTALL is kept from the original so the flags stay the same.
    Pattern inputPattern = Pattern.compile("\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\"", Pattern.DOTALL);
    Matcher inputMatcher = inputPattern.matcher(text);
    if (!inputMatcher.find()) {
        return null;
    }
    String captured = inputMatcher.group(1);
    return captured.replace("\\\"", "\"");
}

public static String findMatchedPart(String text, List<String> patternList) {
Expand Down
Loading

0 comments on commit 36106d0

Please sign in to comment.