Skip to content

Commit

Permalink
adding tests for all the agent runners (opensearch-project#1783)
Browse files Browse the repository at this point in the history
* adding tests for all the agent runners

Signed-off-by: Dhrubo Saha <[email protected]>

* added more tests

Signed-off-by: Dhrubo Saha <[email protected]>

* adding more tests

Signed-off-by: Dhrubo Saha <[email protected]>

* add tests

Signed-off-by: Dhrubo Saha <[email protected]>

* adding more tests

Signed-off-by: Dhrubo Saha <[email protected]>

* added more files and tests

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* merging PR 1785

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored and austintlee committed Mar 18, 2024
1 parent 27d1cab commit 2abcd32
Show file tree
Hide file tree
Showing 25 changed files with 3,135 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import lombok.extern.log4j.Log4j2;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
Expand All @@ -15,9 +16,11 @@
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
Expand Down Expand Up @@ -203,12 +206,27 @@ public static <T extends Enum<T>, S, I extends Object> S initMLInstance(T type,

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteInputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeInputClassMap, type, in, constructorParamClass);
try {
return init(executeInputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
return init(mlInputClassMap, type, in, constructorParamClass);
}
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteOutputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeOutputClassMap, type, in, constructorParamClass);
try {
return init(executeOutputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
if (in instanceof StreamInput) {
try {
return (S) MLOutput.fromStream((StreamInput) in);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
throw e;
}
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -259,7 +277,7 @@ private static <T, S> S init(Map<T, Class<?>> map, T type,
Throwable cause = e.getCause();
if (cause instanceof MLException) {
throw (MLException)cause;
} else if (cause instanceof IllegalArgumentException) {
} else if (cause instanceof IllegalArgumentException) {
throw (IllegalArgumentException) cause;
} else {
log.error("Failed to init instance for type " + type, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,18 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
}
return parameters;
}

public static String toJson(Object value) {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
if (value instanceof String) {
return (String) value;
} else {
return gson.toJson(value);
}
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Client;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
Expand All @@ -33,13 +35,16 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

public class MLCommonsClassLoaderTests {

Expand Down Expand Up @@ -170,6 +175,15 @@ public void testClassLoader_MLInput() throws IOException {
testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
}

@Test(expected = IllegalArgumentException.class)
public void testConnectorInitializationException() {
// Example initialization parameters for connectors
String initParam1 = "parameter1";

// Initialize the first connector type
MLCommonsClassLoader.initConnector("Connector", new Object[]{initParam1}, String.class);
}

public enum TestEnum {
TEST
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.Output;
Expand All @@ -16,5 +17,5 @@ public interface Executable {
* @param input input data
* @return execution result
*/
Output execute(Input input) throws ExecuteException;
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand Down Expand Up @@ -152,20 +153,20 @@ public MLOutput trainAndPredict(Input input) {
return trainAndPredictable.trainAndPredict(mlInput);
}

public Output execute(Input input) throws Exception {
public void execute(Input input, ActionListener<Output> listener) throws Exception {
validateInput(input);
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
} else {
Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.io.FileUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

public abstract Output execute(Input input) throws ExecuteException;
public abstract void execute(Input input, ActionListener<Output> listener);

protected Predictor<float[][], ai.djl.modality.Output> getPredictor() {
int currentDevice = nextDevice.getAndIncrement();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
if (parameters.containsKey(EXAMPLES)) {
String examples = parameters.get(EXAMPLES);
List<String> exampleList = gson.fromJson(examples, List.class);
StringBuilder exampleBuilder = new StringBuilder();
exampleBuilder.append("EXAMPLES\n--------\n");
String examplesPrefix = Optional
.ofNullable(parameters.get("examples.prefix"))
.orElse("You should follow and learn from examples defined in <examples>: \n" + "<examples>\n");
String examplesSuffix = Optional.ofNullable(parameters.get("examples.suffix")).orElse("</examples>\n");
exampleBuilder.append(examplesPrefix);

String examplePrefix = Optional.ofNullable(parameters.get("examples.example.prefix")).orElse("<example>\n");
String exampleSuffix = Optional.ofNullable(parameters.get("examples.example.suffix")).orElse("\n</example>\n");
for (String example : exampleList) {
exampleBuilder.append(examplePrefix).append(example).append(exampleSuffix);
}
exampleBuilder.append(examplesSuffix);
examplesMap.put(EXAMPLES, exampleBuilder.toString());
} else {
examplesMap.put(EXAMPLES, "");
}
StringSubstitutor substitutor = new StringSubstitutor(examplesMap, "${parameters.", "}");
return substitutor.replace(prompt);
}

public static String addPrefixSuffixToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> prefixMap = new HashMap<>();
String prefix = parameters.getOrDefault(PROMPT_PREFIX, "");
String suffix = parameters.getOrDefault(PROMPT_SUFFIX, "");
prefixMap.put(PROMPT_PREFIX, prefix);
prefixMap.put(PROMPT_SUFFIX, suffix);
StringSubstitutor substitutor = new StringSubstitutor(prefixMap, "${parameters.", "}");
return substitutor.replace(prompt);
}

public static String addToolsToPrompt(Map<String, Tool> tools, Map<String, String> parameters, List<String> inputTools, String prompt) {
StringBuilder toolsBuilder = new StringBuilder();
StringBuilder toolNamesBuilder = new StringBuilder();

String toolsPrefix = Optional
.ofNullable(parameters.get("agent.tools.prefix"))
.orElse("You have access to the following tools defined in <tools>: \n" + "<tools>\n");
String toolsSuffix = Optional.ofNullable(parameters.get("agent.tools.suffix")).orElse("</tools>\n");
String toolPrefix = Optional.ofNullable(parameters.get("agent.tools.tool.prefix")).orElse("<tool>\n");
String toolSuffix = Optional.ofNullable(parameters.get("agent.tools.tool.suffix")).orElse("\n</tool>\n");
toolsBuilder.append(toolsPrefix);
for (String toolName : inputTools) {
if (!tools.containsKey(toolName)) {
throw new IllegalArgumentException("Tool [" + toolName + "] not registered for model");
}
toolsBuilder.append(toolPrefix).append(toolName).append(": ").append(tools.get(toolName).getDescription()).append(toolSuffix);
toolNamesBuilder.append(toolName).append(", ");
}
toolsBuilder.append(toolsSuffix);
Map<String, String> toolsPromptMap = new HashMap<>();
toolsPromptMap.put(TOOL_DESCRIPTIONS, toolsBuilder.toString());
toolsPromptMap.put(TOOL_NAMES, toolNamesBuilder.substring(0, toolNamesBuilder.length() - 1));

if (parameters.containsKey(TOOL_DESCRIPTIONS)) {
toolsPromptMap.put(TOOL_DESCRIPTIONS, parameters.get(TOOL_DESCRIPTIONS));
}
if (parameters.containsKey(TOOL_NAMES)) {
toolsPromptMap.put(TOOL_NAMES, parameters.get(TOOL_NAMES));
}
StringSubstitutor substitutor = new StringSubstitutor(toolsPromptMap, "${parameters.", "}");
return substitutor.replace(prompt);
}

public static String addIndicesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> indicesMap = new HashMap<>();
if (parameters.containsKey(OS_INDICES)) {
String indices = parameters.get(OS_INDICES);
List<String> indicesList = gson.fromJson(indices, List.class);
StringBuilder indicesBuilder = new StringBuilder();
String indicesPrefix = Optional
.ofNullable(parameters.get("opensearch_indices.prefix"))
.orElse("You have access to the following OpenSearch Index defined in <opensearch_indexes>: \n" + "<opensearch_indexes>\n");
String indicesSuffix = Optional.ofNullable(parameters.get("opensearch_indices.suffix")).orElse("</opensearch_indexes>\n");
String indexPrefix = Optional.ofNullable(parameters.get("opensearch_indices.index.prefix")).orElse("<index>\n");
String indexSuffix = Optional.ofNullable(parameters.get("opensearch_indices.index.suffix")).orElse("\n</index>\n");
indicesBuilder.append(indicesPrefix);
for (String e : indicesList) {
indicesBuilder.append(indexPrefix).append(e).append(indexSuffix);
}
indicesBuilder.append(indicesSuffix);
indicesMap.put(OS_INDICES, indicesBuilder.toString());
} else {
indicesMap.put(OS_INDICES, "");
}
StringSubstitutor substitutor = new StringSubstitutor(indicesMap, "${parameters.", "}");
return substitutor.replace(prompt);
}

public static String addChatHistoryToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> chatHistoryMap = new HashMap<>();
String chatHistory = parameters.getOrDefault(CHAT_HISTORY, "");
chatHistoryMap.put(CHAT_HISTORY, chatHistory);
parameters.put(CHAT_HISTORY, chatHistory);
StringSubstitutor substitutor = new StringSubstitutor(chatHistoryMap, "${parameters.", "}");
return substitutor.replace(prompt);
}

public static String addContextToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> contextMap = new HashMap<>();
contextMap.put(CONTEXT, parameters.getOrDefault(CONTEXT, ""));
parameters.put(CONTEXT, contextMap.get(CONTEXT));
if (contextMap.size() > 0) {
StringSubstitutor substitutor = new StringSubstitutor(contextMap, "${parameters.", "}");
return substitutor.replace(prompt);
}
return prompt;
}
}
Loading

0 comments on commit 2abcd32

Please sign in to comment.