Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 11, 2024
1 parent 07d08cb commit 40fa88c
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -28,13 +27,20 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {

public static final String SELECTED_TOOLS = "selected_tools";
public static final String PROMPT_PREFIX = "prompt.prefix";
public static final String PROMPT_SUFFIX = "prompt.suffix";
public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
public static final String TOOL_RESPONSE = "prompt.tool_response";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
if (parameters.containsKey(EXAMPLES)) {
Expand Down Expand Up @@ -197,4 +203,53 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
public static String getToolName(MLToolSpec toolSpec) {
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
}

public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
String selectedToolsStr = params.get(SELECTED_TOOLS);
List<MLToolSpec> toolSpecs = mlAgent.getTools();
if (selectedToolsStr != null) {
List<String> selectedTools = gson.fromJson(selectedToolsStr, List.class);
Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
for (MLToolSpec toolSpec : toolSpecs) {
toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
}
List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
for (String tool : selectedTools) {
if (toolNameSpecMap.containsKey(tool)) {
selectedToolSpecs.add(toolNameSpecMap.get(tool));
}
}
toolSpecs = selectedToolSpecs;
}
return toolSpecs;
}

public static void createTools(Map<String, Tool.Factory> toolFactories,
Map<String, String> params,
List<MLToolSpec> toolSpecs,
Map<String, Tool> tools,
Map<String, MLToolSpec> toolSpecMap) {
for (MLToolSpec toolSpec : toolSpecs) {
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}

tools.put(toolName, tool);
toolSpecMap.put(toolName, toolSpec);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class MLAgentExecutor implements Executable {

public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
public static final String PARENT_INTERACTION_ID = "interaction_id";
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";
public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

Expand Down Expand Up @@ -73,10 +79,8 @@
public class MLChatAgentRunner implements MLAgentRunner {

public static final String SESSION_ID = "session_id";
public static final String PROMPT_PREFIX = "prompt_prefix";
public static final String LLM_TOOL_PROMPT_PREFIX = "LanguageModelTool.prompt_prefix";
public static final String LLM_TOOL_PROMPT_SUFFIX = "LanguageModelTool.prompt_suffix";
public static final String PROMPT_SUFFIX = "prompt_suffix";
public static final String TOOLS = "tools";
public static final String TOOL_DESCRIPTIONS = "tool_descriptions";
public static final String TOOL_NAMES = "tool_names";
Expand Down Expand Up @@ -121,6 +125,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje

ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
//TODO: call runAgent directly if messageHistoryLimit == 0
memory.getMessages(ActionListener.<List<Interaction>>wrap(r -> {
List<Message> messageList = new ArrayList<>();
for (Interaction next : r) {
Expand Down Expand Up @@ -160,34 +165,10 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
}

private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, Memory memory, String sessionId) {
List<MLToolSpec> toolSpecs = mlAgent.getTools();
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);
Map<String, Tool> tools = new HashMap<>();
Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
for (MLToolSpec toolSpec : toolSpecs) {
Map<String, String> toolParams = new HashMap<>();
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
toolParams.putAll(toolSpec.getParameters());
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
if (key.startsWith(toolSpec.getType() + ".")) {
executeParams.put(key.replace(toolSpec.getType() + ".", ""), params.get(key));
}
}
log.info("Fetching tool for type: " + toolSpec.getType());
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
if (toolSpec.getName() != null) {
tool.setName(toolSpec.getName());
}

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
String toolName = Optional.ofNullable(tool.getName()).orElse(toolSpec.getType());
tools.put(toolName, tool);
toolSpecMap.put(toolName, toolSpec);
}
createTools(toolFactories, params, toolSpecs, tools, toolSpecMap);

runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, listener);
}
Expand Down Expand Up @@ -229,30 +210,25 @@ private void runReAct(
);
}

String prompt = parameters.get(PROMPT);
if (prompt == null) {
prompt = PromptTemplate.PROMPT_TEMPLATE;
}
String promptPrefix = parameters.getOrDefault("prompt.prefix", PromptTemplate.PROMPT_TEMPLATE_PREFIX);
tmpParameters.put("prompt.prefix", promptPrefix);
String prompt = parameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE);
String promptPrefix = parameters.getOrDefault(PROMPT_PREFIX, PromptTemplate.PROMPT_TEMPLATE_PREFIX);
tmpParameters.put(PROMPT_PREFIX, promptPrefix);

String promptSuffix = parameters.getOrDefault("prompt.suffix", PromptTemplate.PROMPT_TEMPLATE_SUFFIX);
tmpParameters.put("prompt.suffix", promptSuffix);
String promptSuffix = parameters.getOrDefault(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX);
tmpParameters.put(PROMPT_SUFFIX, promptSuffix);

String promptFormatInstruction = parameters.getOrDefault("prompt.format_instruction", PromptTemplate.PROMPT_FORMAT_INSTRUCTION);
tmpParameters.put("prompt.format_instruction", promptFormatInstruction);
if (!tmpParameters.containsKey("prompt.tool_response")) {
tmpParameters.put("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
}
String promptToolResponse = parameters.getOrDefault("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
tmpParameters.put("prompt.tool_response", promptToolResponse);
String promptFormatInstruction = parameters.getOrDefault(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION);
tmpParameters.put(RESPONSE_FORMAT_INSTRUCTION, promptFormatInstruction);

String promptToolResponse = parameters.getOrDefault(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
tmpParameters.put(TOOL_RESPONSE, promptToolResponse);

StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}");
prompt = promptSubstitutor.replace(prompt);

final List<String> inputTools = new ArrayList<>();
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getType());
String toolName = entry.getValue().getName();
inputTools.add(toolName);
}

Expand Down Expand Up @@ -325,16 +301,18 @@ private void runReAct(
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
String thoughtResponse = "";
String llmReasoningResponse = null;
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
String response = (String) dataAsMap.get("response");
String thoughtResponse = extractModelResponseJson(response);
llmReasoningResponse = (String) dataAsMap.get("response");
thoughtResponse = extractModelResponseJson(llmReasoningResponse);
dataAsMap = gson.fromJson(thoughtResponse, Map.class);
}
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
if (!dataAsMap.containsKey("thought")) {
if (!dataAsMap.containsKey("thought")) {//TODO: check if we can remove this if block
String response = (String) dataAsMap.get("response");
Pattern pattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL);
Matcher matcher = pattern.matcher(response);
Expand Down Expand Up @@ -370,12 +348,12 @@ private void runReAct(
// TODO: check if verbose
modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs());

if (conversationIndexMemory != null) {
if (conversationIndexMemory != null && finalAnswer == null) {
ConversationIndexMessage msgTemp = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(memory.getType())
.question(question)
.response(thought)
.response(llmReasoningResponse)
.finalAnswer(false)
.sessionId(sessionId)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;

Expand Down Expand Up @@ -62,7 +63,7 @@
public class MLConversationalFlowAgentRunner implements MLAgentRunner {

public static final String CHAT_HISTORY = "chat_history";
public static final String SELECTED_TOOLS = "selected_tools";

private Client client;
private Settings settings;
private ClusterService clusterService;
Expand Down Expand Up @@ -156,8 +157,7 @@ private void runAgent(
Map<String, String> firstToolExecuteParams = null;
StepListener<Object> previousStepListener = null;
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();
String selectedToolsStr = params.get(SELECTED_TOOLS);
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, selectedToolsStr);
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);

if (toolSpecs == null || toolSpecs.size() == 0) {
listener.onFailure(new IllegalArgumentException("no tool configured"));
Expand Down Expand Up @@ -231,25 +231,6 @@ private void runAgent(
}
}

private static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, String selectedToolsStr) {
List<MLToolSpec> toolSpecs = mlAgent.getTools();
if (selectedToolsStr != null) {
List<String> selectedTools = gson.fromJson(selectedToolsStr, List.class);
Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
for (MLToolSpec toolSpec : toolSpecs) {
toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
}
List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
for (String tool : selectedTools) {
if (toolNameSpecMap.containsKey(tool)) {
selectedToolSpecs.add(toolNameSpecMap.get(tool));
}
}
toolSpecs = selectedToolSpecs;
}
return toolSpecs;
}

private void processOutput(
Map<String, String> params,
ActionListener<Object> listener,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine.algorithms.agent;

import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -73,7 +74,7 @@ public MLFlowAgentRunner(

@Override
public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
List<MLToolSpec> toolSpecs = mlAgent.getTools();
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);
StepListener<Object> firstStepListener = null;
Tool firstTool = null;
List<ModelTensor> flowAgentOutput = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void createInteraction(
* @param actionListener get all the final interactions that are not traces
*/
public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener<List<Interaction>> actionListener) {
Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1.");
Preconditions.checkArgument(lastNInteraction > 0, "History message size must be at least 1.");
log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;

import java.util.Arrays;
import java.util.HashMap;
Expand Down

0 comments on commit 40fa88c

Please sign in to comment.