Skip to content

Commit

Permalink
test and clean code
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 12, 2024
1 parent d65eaa4 commit 841973b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,6 @@ public static String extractModelResponseJson(String text) {
return matcher.group();
}
}
//// Pattern pattern2 = Pattern.compile("\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\}");
// Pattern pattern2 = Pattern.compile("\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}");
// Pattern pattern3 = Pattern.compile("\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}");
//
//// Pattern pattern2 = Pattern.compile("\\{\\s*(\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?|\"thought\":.*?\\s*,\\s*\"final_answer\":.*?)\\}");
// Matcher matcher2 = pattern2.matcher(text);
// Matcher matcher3 = pattern3.matcher(text);
// // Find the JSON content
// if (matcher2.find()) {
// return matcher2.group();
// }
// if (matcher3.find()) {
// return matcher3.group();
// }
throw new IllegalArgumentException("Model output is invalid");
}
}
Expand Down Expand Up @@ -256,26 +242,37 @@ public static void createTools(Map<String, Tool.Factory> toolFactories,
Map<String, Tool> tools,
Map<String, MLToolSpec> toolSpecMap) {
for (MLToolSpec toolSpec : toolSpecs) {
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);
Tool tool = createTool(toolFactories, params, toolSpec);
tools.put(tool.getName(), tool);
toolSpecMap.put(tool.getName(), toolSpec);
}
}

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec) {
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);

tools.put(toolName, tool);
toolSpecMap.put(toolName, toolSpec);
if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
if (params.containsKey(toolName + ".description")) {
tool.setDescription(params.get(toolName + ".description"));
}

return tool;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
Expand All @@ -66,7 +64,6 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.ml.repackage.com.google.common.collect.Lists;

Expand Down Expand Up @@ -262,16 +259,8 @@ private void runReAct(
StepListener<?> nextStepListener = new StepListener<>();

lastStepListener.whenComplete(output -> {





//////////////////////////////////////////////////////////////////////////////////////////
// start
//////////////////////////////////////////////////////////////////////////////////////////
StringBuilder sessionMsgAnswerBuilder = new StringBuilder();
if (finalI % 2 == 0) {// Reasoning which tool to use
if (finalI % 2 == 0) {
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
Expand All @@ -284,7 +273,6 @@ private void runReAct(
} catch (IllegalArgumentException e) {
thoughtResponse = llmReasoningResponse;
finalAnswer = llmReasoningResponse;
System.out.println("0000000000 ylwudddebug1: get final answer directly : " + finalAnswer);
}
if (isJson(thoughtResponse)) {
dataAsMap = gson.fromJson(thoughtResponse, Map.class);
Expand Down Expand Up @@ -401,7 +389,7 @@ private void runReAct(
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());
}
} else { // run tool
} else {
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = outputToOutputString(output);
Expand Down Expand Up @@ -430,7 +418,7 @@ private void runReAct(
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
sessionMsgAnswerBuilder.append(outputToOutputString(output));
cotModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build())).build());

//client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
Expand All @@ -448,9 +436,6 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
}
}
//////////////////////////////////////////////////////////////////////////////////////////
// end
//////////////////////////////////////////////////////////////////////////////////////////
}, e -> {
log.error("Failed to run chat agent", e);
listener.onFailure(e);
Expand Down Expand Up @@ -497,25 +482,6 @@ private void saveMessage(
}
}

private GroupedActionListener<ActionResponse> createGroupedListener(final int size, final ActionListener<Boolean> listener) {
return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
@Override
public void onResponse(final Collection<ActionResponse> responses) {
CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class);
log.info("saved message with interaction id: {}", createInteractionResponse.getId());
UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class);
log.info("Updated final answer into interaction id: {}", updateResponse.getId());

listener.onResponse(true);
}

@Override
public void onFailure(final Exception e) {
listener.onFailure(e);
}
}, size);
}

@SuppressWarnings("unchecked")
private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
return (A) responses.stream().filter(c::isInstance).findFirst().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID;
import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
Expand Down Expand Up @@ -173,7 +174,7 @@ private void runAgent(
for (int i = 0; i <= toolSpecs.size(); i++) {
if (i == 0) {
MLToolSpec toolSpec = toolSpecs.get(i);
Tool tool = createTool(toolSpec);
Tool tool = createTool(toolFactories, params, toolSpec);
firstStepListener = new StepListener<>();
previousStepListener = firstStepListener;
firstTool = tool;
Expand Down Expand Up @@ -310,7 +311,7 @@ private void processOutput(

private void runNextStep(Map<String, String> params, List<MLToolSpec> toolSpecs, int finalI, StepListener<Object> nextStepListener) {
MLToolSpec toolSpec = toolSpecs.get(finalI);
Tool tool = createTool(toolSpec);
Tool tool = createTool(toolFactories, params, toolSpec);
if (finalI < toolSpecs.size()) {
tool.run(getToolExecuteParams(toolSpec, params), nextStepListener);
}
Expand Down Expand Up @@ -384,26 +385,6 @@ String parseResponse(Object output) throws IOException {
}
}

@VisibleForTesting
Tool createTool(MLToolSpec toolSpec) {
Map<String, String> toolParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
toolParams.putAll(toolSpec.getParameters());
}
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
if (toolSpec.getName() != null) {
tool.setName(toolSpec.getName());
}

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
return tool;
}

@VisibleForTesting
Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params) {
Map<String, String> executeParams = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,4 @@ public class PromptTemplate {
public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nHuman: follow RESPONSE FORMAT INSTRUCTIONS\n\nAssistant:";
public static final String PROMPT_TEMPLATE_TOOL_RESPONSE =
"Assistant:\n---------------------\n${parameters.llm_tool_selection_response}\n\nHuman: TOOL RESPONSE of ${parameters.tool_name}: \n---------------------\nTool input:\n${parameters.tool_input}\n\nTool output:\n${parameters.observation}\n\n";
public static final String PROMPT_TEMPLATE_ASK_AGAIN = "\n\nUSER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names.";
//public static final String s = "USER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.";
// Human: follow RESPONSE FORMAT INSTRUCTIONS
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
StringBuilder sb = new StringBuilder(
// Currently using c.value which is short header matching _cat/indices
// May prefer to use c.attr.get("desc") for full description
table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining(", ", "", "\n"))
table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining(",", "", "\n"))
);
for (List<Cell> row : table.getRows()) {
sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining(", ", "", "\n")));
sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining(",", "", "\n")));
}
@SuppressWarnings("unchecked")
T response = (T) sb.toString();
Expand Down

0 comments on commit 841973b

Please sign in to comment.