Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 12, 2024
1 parent 07d08cb commit 86223c7
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

log.info("++++++++++++++++++++++++++++++++++++++++++++++++++");
System.out.println(payload);
log.info("--------------------------------------------------");
if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -28,13 +27,20 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {

public static final String SELECTED_TOOLS = "selected_tools";
public static final String PROMPT_PREFIX = "prompt.prefix";
public static final String PROMPT_SUFFIX = "prompt.suffix";
public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
public static final String TOOL_RESPONSE = "prompt.tool_response";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
if (parameters.containsKey(EXAMPLES)) {
Expand Down Expand Up @@ -150,13 +156,25 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
return prompt;
}

public static List<String> MODEL_RESPONSE_PATTERNS = List.of(
"\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"action\"\\s*:\\s*\".*?\"\\s*,\\s*\"action_input\"\\s*:\\s*\".*?\"\\s*}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}"
);
public static String extractModelResponseJson(String text) {
Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher = pattern.matcher(text);
Pattern pattern1 = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher1 = pattern1.matcher(text);

if (matcher.find()) {
return matcher.group(1);
if (matcher1.find()) {
return matcher1.group(1);
} else {
for (String p : MODEL_RESPONSE_PATTERNS) {
Pattern pattern = Pattern.compile(p);
Matcher matcher = pattern.matcher(text);
if (matcher.find()) {
return matcher.group();
}
}
throw new IllegalArgumentException("Model output is invalid");
}
}
Expand Down Expand Up @@ -197,4 +215,64 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
public static String getToolName(MLToolSpec toolSpec) {
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
}

public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
String selectedToolsStr = params.get(SELECTED_TOOLS);
List<MLToolSpec> toolSpecs = mlAgent.getTools();
if (selectedToolsStr != null) {
List<String> selectedTools = gson.fromJson(selectedToolsStr, List.class);
Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
for (MLToolSpec toolSpec : toolSpecs) {
toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
}
List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
for (String tool : selectedTools) {
if (toolNameSpecMap.containsKey(tool)) {
selectedToolSpecs.add(toolNameSpecMap.get(tool));
}
}
toolSpecs = selectedToolSpecs;
}
return toolSpecs;
}

public static void createTools(Map<String, Tool.Factory> toolFactories,
Map<String, String> params,
List<MLToolSpec> toolSpecs,
Map<String, Tool> tools,
Map<String, MLToolSpec> toolSpecMap) {
for (MLToolSpec toolSpec : toolSpecs) {
Tool tool = createTool(toolFactories, params, toolSpec);
tools.put(tool.getName(), tool);
toolSpecMap.put(tool.getName(), toolSpec);
}
}

public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec) {
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
if (params.containsKey(toolName + ".description")) {
tool.setDescription(params.get(toolName + ".description"));
}

return tool;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class MLAgentExecutor implements Executable {

public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
public static final String PARENT_INTERACTION_ID = "interaction_id";
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";
public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit";

Expand Down
Loading

0 comments on commit 86223c7

Please sign in to comment.