Skip to content

Commit

Permalink
Fine-tune prompt; refactor conversational agent code (#2094)
Browse files Browse the repository at this point in the history
* Fine-tune prompt; refactor conversational agent code

Signed-off-by: Yaliang Wu <[email protected]>

* put listener to last

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

* check if selectedToolsStr is empty

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Feb 14, 2024
1 parent f5276ce commit ad14420
Show file tree
Hide file tree
Showing 12 changed files with 754 additions and 559 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -28,13 +27,24 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {

// Parameter key read by getMlToolSpecs: its value is parsed as a JSON list of tool names
// and only the agent tools with those names are kept.
public static final String SELECTED_TOOLS = "selected_tools";
// The keys below are not referenced in the visible part of this file.
// NOTE(review): presumably they are parameter-map keys for overriding parts of the agent prompt — confirm against callers.
public static final String PROMPT_PREFIX = "prompt.prefix";
public static final String PROMPT_SUFFIX = "prompt.suffix";
public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
public static final String TOOL_RESPONSE = "prompt.tool_response";
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
// NOTE(review): presumably boolean-like flags controlling trace/verbose output — confirm against callers.
public static final String DISABLE_TRACE = "disable_trace";
public static final String VERBOSE = "verbose";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
if (parameters.containsKey(EXAMPLES)) {
Expand Down Expand Up @@ -150,17 +160,43 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
return prompt;
}

public static List<String> MODEL_RESPONSE_PATTERNS = List
.of("\\{\\s*(\"(thought|action|action_input|final_answer)\"\\s*:\\s*\".*?\"\\s*,?\\s*)+\\}");

public static String extractModelResponseJson(String text) {
Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher = pattern.matcher(text);
return extractModelResponseJson(text, null);
}

if (matcher.find()) {
return matcher.group(1);
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);

if (jsonBlockMatcher.find()) {
return jsonBlockMatcher.group(1);
} else {
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
if (matchedPart == null && llmResponsePatterns != null) {
// If no match is found, try additional patterns if provided
matchedPart = findMatchedPart(text, llmResponsePatterns);
}
if (matchedPart != null) {
return matchedPart;
}
throw new IllegalArgumentException("Model output is invalid");
}
}

public static String findMatchedPart(String text, List<String> patternList) {
for (String p : patternList) {
Pattern pattern = Pattern.compile(p);
Matcher matcher = pattern.matcher(text);
if (matcher.find()) {
return matcher.group();
}
}
return null;
}

public static String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
Expand All @@ -179,16 +215,6 @@ public static String outputToOutputString(Object output) throws PrivilegedAction
return outputString;
}

/**
 * Renders the {@code action_input} entry of a parsed LLM response as a string.
 *
 * @param retMap parsed LLM response
 * @return the JSON serialization when the entry is a {@link Map}; otherwise the result of
 *         {@link String#valueOf(Object)} on the entry (so a missing entry yields {@code "null"})
 */
public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
    final Object rawInput = retMap.get("action_input");
    return rawInput instanceof Map ? gson.toJson(rawInput) : String.valueOf(rawInput);
}

public static int getMessageHistoryLimit(Map<String, String> params) {
String messageHistoryLimitStr = params.get(MESSAGE_HISTORY_LIMIT);
return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS;
Expand All @@ -197,4 +223,75 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
/**
 * Resolves the name under which a tool spec is addressed.
 *
 * @param toolSpec tool specification
 * @return the spec's name when one is set, otherwise its type
 */
public static String getToolName(MLToolSpec toolSpec) {
    final String explicitName = toolSpec.getName();
    if (explicitName == null) {
        return toolSpec.getType();
    }
    return explicitName;
}

/**
 * Returns the tool specs the agent should use for this request.
 * <p>
 * When {@code params} carries a non-empty {@link #SELECTED_TOOLS} value, it is parsed as a JSON
 * list of tool names and only the agent tools with those names are returned, in the order the
 * names were listed; names that match no tool are ignored. Otherwise all agent tools are returned.
 *
 * @param mlAgent agent whose tools are considered
 * @param params request parameters
 * @return the selected tool specs, or the agent's full tool list when no selection is given
 */
public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
    final String selection = params.get(SELECTED_TOOLS);
    final List<MLToolSpec> allSpecs = mlAgent.getTools();
    if (Strings.isEmpty(selection)) {
        return allSpecs;
    }

    List<String> requestedNames = gson.fromJson(selection, List.class);

    // Index the agent's tools by name so each requested name is a single lookup.
    Map<String, MLToolSpec> specsByName = new HashMap<>();
    for (MLToolSpec spec : allSpecs) {
        specsByName.put(getToolName(spec), spec);
    }

    List<MLToolSpec> chosen = new ArrayList<>();
    for (String requested : requestedNames) {
        if (!specsByName.containsKey(requested)) {
            continue;
        }
        chosen.add(specsByName.get(requested));
    }
    return chosen;
}

/**
 * Instantiates one tool per spec and registers it, keyed by the tool's name, in both output maps.
 *
 * @param toolFactories factories keyed by tool type
 * @param params request parameters forwarded to {@code createTool}
 * @param toolSpecs specs to instantiate
 * @param tools output map: tool name to created tool
 * @param toolSpecMap output map: tool name to the spec the tool was created from
 */
public static void createTools(
    Map<String, Tool.Factory> toolFactories,
    Map<String, String> params,
    List<MLToolSpec> toolSpecs,
    Map<String, Tool> tools,
    Map<String, MLToolSpec> toolSpecMap
) {
    for (MLToolSpec spec : toolSpecs) {
        final Tool created = createTool(toolFactories, params, spec);
        final String registeredName = created.getName();
        tools.put(registeredName, created);
        toolSpecMap.put(registeredName, spec);
    }
}

/**
 * Creates and configures a single tool from its spec.
 * <p>
 * The tool's creation parameters are the spec's own parameters, overlaid with every request
 * parameter whose key starts with {@code "<toolName>."}; that leading prefix is stripped from
 * the key. The tool is then named via {@code getToolName} and its description is taken from the
 * spec, unless the request supplies {@code "<toolName>.description"}, which wins.
 *
 * @param toolFactories factories keyed by tool type
 * @param params request parameters
 * @param toolSpec spec describing the tool to create
 * @return the created tool with name and description applied
 * @throws IllegalArgumentException if no factory is registered for the spec's type
 */
public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec) {
    Tool.Factory factory = toolFactories.get(toolSpec.getType());
    if (factory == null) {
        throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
    }

    Map<String, String> executeParams = new HashMap<>();
    if (toolSpec.getParameters() != null) {
        executeParams.putAll(toolSpec.getParameters());
    }

    String toolName = getToolName(toolSpec);
    String toolNamePrefix = toolName + ".";
    for (Map.Entry<String, String> entry : params.entrySet()) {
        String key = entry.getKey();
        if (key.startsWith(toolNamePrefix)) {
            // Strip only the leading prefix: String.replace would also delete later occurrences
            // of "<toolName>." inside the key (e.g. "search.index.search.field").
            executeParams.put(key.substring(toolNamePrefix.length()), entry.getValue());
        }
    }

    Tool tool = factory.create(executeParams);
    tool.setName(toolName);

    if (toolSpec.getDescription() != null) {
        tool.setDescription(toolSpec.getDescription());
    }
    // A request-level description overrides the one configured on the spec.
    String descriptionKey = toolName + ".description";
    if (params.containsKey(descriptionKey)) {
        tool.setDescription(params.get(descriptionKey));
    }

    return tool;
}

/**
 * Collects the names reported by the tools in the given map, in the map's iteration order.
 *
 * @param tools tools keyed by name; only the values are consulted
 * @return a new list holding each tool's {@code getName()} result
 */
public static List<String> getToolNames(Map<String, Tool> tools) {
    final List<String> names = new ArrayList<>();
    for (Tool tool : tools.values()) {
        names.add(tool.getName());
    }
    return names;
}
}
Loading

0 comments on commit ad14420

Please sign in to comment.