diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index bdccaa86a7..e51e71f678 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -557,13 +557,13 @@ private void runReAct( } } else { Object result = output; - Tool tool = tools.get(lastAction.get()); - if (tool != null && tool.includeOutputInAgentResponse()) { + MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); + if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { String outputString = output instanceof String ? (String) output : AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); - String toolOutputKey = String.format("%s.output", tool.getType()); + String toolOutputKey = String.format("%s.output", toolSpec.getType()); if (additionalInfo.get(toolOutputKey) != null) { List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java index 495d42dc50..91c3237ac8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java @@ -5,8 +5,6 @@ package org.opensearch.ml.engine.tools; -import static org.opensearch.ml.common.agent.MLToolSpec.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; - import java.util.Arrays; import java.util.Map; import java.util.Optional; @@ -47,13 +45,11 @@ public class VisualizationsTool implements Tool { private final Client client; @Getter private final String index; - private final boolean includeOutputInAgentResponse; @Builder - public VisualizationsTool(Client client, String index, boolean includeOutputInAgentResponse) { + public VisualizationsTool(Client client, String index) { this.client = client; this.index = index; - this.includeOutputInAgentResponse = includeOutputInAgentResponse; } @Override @@ -142,11 +138,6 @@ public boolean validate(Map parameters) { return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); } - @Override - public boolean includeOutputInAgentResponse() { - return this.includeOutputInAgentResponse; - } - public static class Factory implements Tool.Factory { private Client client; @@ -172,14 +163,7 @@ public void init(Client client) { @Override public VisualizationsTool create(Map params) { String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); - boolean includeOutputInAgentResponse = params.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE) != null - && Boolean.parseBoolean((String) params.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE)); - return VisualizationsTool - .builder() - .client(client) - .index(index) - .includeOutputInAgentResponse(includeOutputInAgentResponse) - .build(); + return VisualizationsTool.builder().client(client).index(index).build(); } @Override