Skip to content

Commit

Permalink
Save tool output into additional info
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>

Add includeToolputInAgentResponse flag

Signed-off-by: Hailong Cui <[email protected]>

Apply spotless

Signed-off-by: Hailong Cui <[email protected]>

address review comments

Signed-off-by: Hailong Cui <[email protected]>

address review comments

Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Nov 23, 2023
1 parent fb9a2f2 commit 801f917
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -29,6 +33,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.LLMSpec;
Expand All @@ -51,6 +56,7 @@
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.ml.repackage.com.google.common.collect.Lists;

import lombok.Data;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -118,6 +124,11 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
Interaction next = iterator.next();
String question = next.getInput();
String response = next.getResponse();
// As we store the conversation with empty response first and then update when have final answer,
// filter out those in-flight requests when run in parallel
if (Strings.isNullOrEmpty(response)) {
continue;
}
messageList
.add(
ConversationIndexMessage
Expand Down Expand Up @@ -170,7 +181,9 @@ private void runAgent(
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
tool.setName(toolSpec.getName());
if (toolSpec.getName() != null) {
tool.setName(toolSpec.getName());
}

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
Expand Down Expand Up @@ -310,6 +323,7 @@ private void runReAct(
AtomicReference<String> lastAction = new AtomicReference<>();
AtomicReference<String> lastActionInput = new AtomicReference<>();
AtomicReference<String> lastActionResult = new AtomicReference<>();
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();

StepListener<?> lastStepListener = null;
int maxIterations = Integer.parseInt(maxIteration) * 2;
Expand Down Expand Up @@ -390,7 +404,7 @@ private void runReAct(
.getMemoryManager()
.updateInteraction(
r.getId(),
ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1),
ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo),
ActionListener.<UpdateResponse>wrap(updateResponse -> {
log.info("Updated final answer into interaction id: {}", r.getId());
log.info("Final answer: {}", finalAnswer1);
Expand All @@ -407,19 +421,14 @@ private void runReAct(
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
Map<String, Object> additionalInfoMap = new HashMap<>(additionalInfo);
additionalInfoMap.put("response", finalAnswer);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Arrays
.asList(
ModelTensor
.builder()
.name("response")
.dataAsMap(ImmutableMap.of("response", finalAnswer))
.build()
)
Arrays.asList(ModelTensor.builder().name("response").dataAsMap(additionalInfoMap).build())
)
.build()
);
Expand Down Expand Up @@ -548,6 +557,21 @@ private void runReAct(
}
} else {
Object result = output;
Tool tool = tools.get(lastAction.get());
if (tool != null && tool.includeOutputInAgentResponse()) {
String outputString = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));

String toolOutputKey = String.format("%s.output", tool.getType());
if (additionalInfo.get(toolOutputKey) != null) {
List<String> list = (List<String>) additionalInfo.get(toolOutputKey);
list.add(outputString);
} else {
additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString));
}

}
modelTensors
.add(
ModelTensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.agent.MLToolSpec.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;

import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -45,11 +47,13 @@ public class VisualizationsTool implements Tool {
private final Client client;
@Getter
private final String index;
private final boolean includeOutputInAgentResponse;

@Builder
public VisualizationsTool(Client client, String index) {
public VisualizationsTool(Client client, String index, boolean includeOutputInAgentResponse) {
this.client = client;
this.index = index;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
}

@Override
Expand Down Expand Up @@ -138,6 +142,11 @@ public boolean validate(Map<String, String> parameters) {
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input"));
}

@Override
public boolean includeOutputInAgentResponse() {
return this.includeOutputInAgentResponse;
}

public static class Factory implements Tool.Factory<VisualizationsTool> {
private Client client;

Expand All @@ -163,7 +172,14 @@ public void init(Client client) {
@Override
public VisualizationsTool create(Map<String, Object> params) {
String index = params.get("index") == null ? ".kibana" : (String) params.get("index");
return VisualizationsTool.builder().client(client).index(index).build();
boolean includeOutputInAgentResponse = params.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE) != null
&& Boolean.parseBoolean((String) params.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE));
return VisualizationsTool
.builder()
.client(client)
.index(index)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ default boolean useOriginalInput() {
return false;
}

/**
* Whether we should include tool's output in agent response
* @return true/false
*/
default boolean includeOutputInAgentResponse() {
return false;
}

/**
* Tool factory which can create instance of {@link Tool}.
* @param <T> The subclass this factory produces
Expand Down

0 comments on commit 801f917

Please sign in to comment.