From 45af77894c2237fefeaae74140202b5ac6f21b7f Mon Sep 17 00:00:00 2001
From: Mingshi Liu <113382730+mingshl@users.noreply.github.com>
Date: Wed, 13 Dec 2023 10:22:46 -0800
Subject: [PATCH] Add RAGTool (#1747)

* Add RAGTool

Signed-off-by: Mingshi Liu

* Add static variable and reuse getQueryBody

Signed-off-by: Mingshi Liu

---------

Signed-off-by: Mingshi Liu
---
 .../engine/tools/AbstractRetrieverTool.java   |   5 +-
 .../opensearch/ml/engine/tools/RAGTool.java   | 342 ++++++++++++++++++
 .../ml/engine/tools/VectorDBTool.java         |   2 +-
 .../ml/plugin/MachineLearningPlugin.java      |   2 +
 4 files changed, 348 insertions(+), 3 deletions(-)
 create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/RAGTool.java

diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java
index 587dfeb7f9..a33b46421e 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java
@@ -42,7 +42,8 @@ public abstract class AbstractRetrieverTool implements Tool {
     public static final String INDEX_FIELD = "index";
     public static final String SOURCE_FIELD = "source_field";
     public static final String DOC_SIZE_FIELD = "doc_size";
-
+    public static final int DEFAULT_DOC_SIZE = 2;
+    public static final int DEFAULT_K = 10;
     protected String description = DEFAULT_DESCRIPTION;
     protected Client client;
     protected NamedXContentRegistry xContentRegistry;
@@ -61,7 +62,7 @@ protected AbstractRetrieverTool(
         this.xContentRegistry = xContentRegistry;
         this.index = index;
         this.sourceFields = sourceFields;
-        this.docSize = docSize == null ? 2 : docSize;
+        this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
     }
 
     protected abstract String getQueryBody(String queryText);
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/RAGTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/RAGTool.java
new file mode 100644
index 0000000000..352604afe6
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/RAGTool.java
@@ -0,0 +1,342 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.tools;
+
+import static org.apache.commons.text.StringEscapeUtils.escapeJson;
+import static org.opensearch.ml.common.utils.StringUtils.gson;
+
+import java.io.IOException;
+import java.security.AccessController;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.lang3.StringUtils;
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.Client;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
+import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.common.spi.tools.Parser;
+import org.opensearch.ml.common.spi.tools.ToolAnnotation;
+import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.builder.SearchSourceBuilder;
+
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * Retrieval-augmented generation (RAG) tool: runs a neural search against an index and passes the
+ * retrieved documents, together with the caller's parameters, to a remote model.
+ */
+@Log4j2
+@Setter
+@Getter
+@ToolAnnotation(RAGTool.TYPE)
+public class RAGTool extends AbstractRetrieverTool {
+    public static final String TYPE = "RAGTool";
+    private static final String DEFAULT_DESCRIPTION = "Use this tool to run any model.";
+    public static final String MODEL_ID_FIELD = "model_id";
+    public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
+    public static final String EMBEDDING_FIELD = "embedding_field";
+    public static final String OUTPUT_FIELD = "output_field";
+    private String name = TYPE;
+    private String description = DEFAULT_DESCRIPTION;
+    private Client client;
+    private String modelId;
+
+    private NamedXContentRegistry xContentRegistry;
+    private String index;
+    private String embeddingField;
+    private String[] sourceFields;
+    private String embeddingModelId;
+    private Integer docSize;
+    private Integer k;
+    @Setter
+    private Parser inputParser;
+    @Setter
+    private Parser outputParser;
+
+    /**
+     * Creates a RAG tool.
+     *
+     * @param client client used for the search and the prediction request
+     * @param xContentRegistry registry used when parsing the generated query body
+     * @param index index to search
+     * @param embeddingField field targeted by the neural query
+     * @param sourceFields source fields fetched for each hit
+     * @param k neural query {@code k}; {@code DEFAULT_K} when null
+     * @param docSize number of hits requested; {@code DEFAULT_DOC_SIZE} when null
+     * @param embeddingModelId model id placed in the neural query
+     * @param modelId id of the model that receives the retrieved context
+     */
+    @Builder
+    public RAGTool(
+        Client client,
+        NamedXContentRegistry xContentRegistry,
+        String index,
+        String embeddingField,
+        String[] sourceFields,
+        Integer k,
+        Integer docSize,
+        String embeddingModelId,
+        String modelId
+    ) {
+        super(client, xContentRegistry, index, sourceFields, docSize);
+        this.client = client;
+        this.xContentRegistry = xContentRegistry;
+        this.index = index;
+        this.embeddingField = embeddingField;
+        this.sourceFields = sourceFields;
+        this.embeddingModelId = embeddingModelId;
+        this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
+        this.k = k == null ? DEFAULT_K : k;
+        this.modelId = modelId;
+
+        // By default, return only the "response" entry of the first tensor of the first model output.
+        outputParser = new Parser() {
+            @Override
+            public Object parse(Object o) {
+                List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
+                return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
+            }
+        };
+    }
+
+    /**
+     * Builds the neural query sent to the index.
+     *
+     * @param queryText question text; JSON-escaped before it is embedded in the query
+     * @return the query body as a JSON string
+     * @throws IllegalArgumentException if the embedding field or the embedding model id is blank
+     */
+    @Override
+    protected String getQueryBody(String queryText) {
+        if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(embeddingModelId)) {
+            throw new IllegalArgumentException(
+                "Parameter [" + EMBEDDING_FIELD + "] and [" + EMBEDDING_MODEL_ID_FIELD + "] can not be null or empty."
+            );
+        }
+        return "{\"query\":{\"neural\":{\""
+            + embeddingField
+            + "\":{\"query_text\":\""
+            + escapeJson(queryText)
+            + "\",\"model_id\":\""
+            + embeddingModelId
+            + "\",\"k\":"
+            + k
+            + "}}}"
+            + " }";
+    }
+
+    /**
+     * Searches the index for documents matching the question, then sends the caller's parameters
+     * plus the retrieved context (under {@code output_field}) to the model.
+     *
+     * @param parameters tool parameters; the question is read from {@code INPUT_FIELD}
+     * @param listener receives the (optionally parsed) model output, or the search/prediction failure
+     * @throws IllegalArgumentException if the question is blank or the embedding settings are missing
+     */
+    @Override
+    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
+        String question = parameters.get(INPUT_FIELD);
+        try {
+            question = gson.fromJson(question, String.class);
+        } catch (Exception ignored) {
+            // Best effort: the input may be plain text rather than a JSON-encoded string; use it as is.
+        }
+        if (StringUtils.isBlank(question)) {
+            throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
+        }
+        String query = getQueryBody(question);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        // Close the parser as soon as the query has been read into the builder.
+        try (
+            XContentParser queryParser = XContentType.JSON
+                .xContent()
+                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query)
+        ) {
+            searchSourceBuilder.parseXContent(queryParser);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        searchSourceBuilder.fetchSource(sourceFields, null);
+        searchSourceBuilder.size(docSize);
+        SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
+
+        ActionListener<SearchResponse> searchListener = ActionListener.wrap(searchResponse -> {
+            Map<String, String> modelParameters = new HashMap<>(parameters);
+            modelParameters.put(OUTPUT_FIELD, buildContext(searchResponse.getHits().getHits()));
+
+            RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(modelParameters).build();
+            ActionRequest request = new MLPredictionTaskRequest(
+                modelId,
+                MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
+            );
+            client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.<MLTaskResponse>wrap(resp -> {
+                ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
+                if (outputParser == null) {
+                    listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
+                } else {
+                    listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
+                }
+            }, e -> {
+                log.error("Failed to run model " + modelId, e);
+                listener.onFailure(e);
+            }));
+        }, e -> {
+            log.error("Failed to search index", e);
+            listener.onFailure(e);
+        });
+        client.search(searchRequest, searchListener);
+    }
+
+    /**
+     * Serializes each hit's id and source as one JSON document per line.
+     *
+     * @param hits search hits, may be null or empty
+     * @return the JSON-encoded concatenation of the documents, or an empty string when there are no hits
+     */
+    private String buildContext(SearchHit[] hits) throws PrivilegedActionException {
+        if (hits == null || hits.length == 0) {
+            return "";
+        }
+        StringBuilder contextBuilder = new StringBuilder();
+        for (SearchHit hit : hits) {
+            String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
+                Map<String, Object> docContent = new HashMap<>();
+                docContent.put("_id", hit.getId());
+                docContent.put("_source", hit.getSourceAsMap());
+                return gson.toJson(docContent);
+            });
+            contextBuilder.append(doc).append("\n");
+        }
+        // The joined documents are JSON-encoded once more, so the result is a quoted JSON string.
+        return gson.toJson(contextBuilder.toString());
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    @Override
+    public String getName() {
+        return this.name;
+    }
+
+    @Override
+    public void setName(String s) {
+        this.name = s;
+    }
+
+    /**
+     * @return {@code true} when the parameters contain a non-null {@code INPUT_FIELD} entry
+     */
+    @Override
+    public boolean validate(Map<String, String> parameters) {
+        if (parameters == null || parameters.size() == 0) {
+            return false;
+        }
+        String question = parameters.get(INPUT_FIELD);
+        return question != null;
+    }
+
+    /**
+     * Singleton factory that builds {@link RAGTool} instances from tool parameters.
+     */
+    public static class Factory extends AbstractRetrieverTool.Factory<RAGTool> {
+        private Client client;
+        private NamedXContentRegistry xContentRegistry;
+
+        // volatile: getInstance() reads this field without holding the lock
+        private static volatile Factory INSTANCE;
+
+        public static Factory getInstance() {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            synchronized (RAGTool.class) {
+                if (INSTANCE != null) {
+                    return INSTANCE;
+                }
+                INSTANCE = new Factory();
+                return INSTANCE;
+            }
+        }
+
+        public void init(Client client, NamedXContentRegistry xContentRegistry) {
+            this.client = client;
+            this.xContentRegistry = xContentRegistry;
+        }
+
+        @Override
+        public RAGTool create(Map<String, Object> params) {
+            String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD);
+            String index = (String) params.get(INDEX_FIELD);
+            String embeddingField = (String) params.get(EMBEDDING_FIELD);
+            String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
+            String modelId = (String) params.get(MODEL_ID_FIELD);
+            Integer docSize = params.containsKey(DOC_SIZE_FIELD)
+                ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD))
+                : DEFAULT_DOC_SIZE;
+            return RAGTool
+                .builder()
+                .client(client)
+                .xContentRegistry(xContentRegistry)
+                .index(index)
+                .embeddingField(embeddingField)
+                .sourceFields(sourceFields)
+                .embeddingModelId(embeddingModelId)
+                .docSize(docSize)
+                .modelId(modelId)
+                .build();
+        }
+
+        @Override
+        public String getDefaultDescription() {
+            return DEFAULT_DESCRIPTION;
+        }
+    }
+
+    /**
+     * Converts a value to its string form inside a privileged block.
+     *
+     * @param value value to convert
+     * @return the value itself when it is a {@code String}, otherwise its JSON serialization
+     * @throws RuntimeException wrapping the {@link PrivilegedActionException} if the conversion fails
+     */
+    public static String getString(Object value) {
+        try {
+            return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
+                if (value instanceof String) {
+                    return (String) value;
+                } else {
+                    return gson.toJson(value);
+                }
+            });
+        } catch (PrivilegedActionException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java
index 468bf8f7d1..1ff26e92e0 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java
@@ -49,7 +49,7 @@ public VectorDBTool(
         super(client, xContentRegistry, index, sourceFields, docSize);
         this.modelId = modelId;
         this.embeddingField = embeddingField;
-        this.k = k == null ? 10 : k;
+        this.k = k == null ? DEFAULT_K : k;
     }
 
     @Override
diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
index 6556fb639d..94778f34a2 100644
--- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
+++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
@@ -516,6 +516,7 @@ public Collection createComponents(
         CatIndexTool.Factory.getInstance().init(client, clusterService);
         PainlessScriptTool.Factory.getInstance().init(client, scriptService);
         VisualizationsTool.Factory.getInstance().init(client);
+        RAGTool.Factory.getInstance().init(client, xContentRegistry);
         SearchAlertsTool.Factory.getInstance().init(client);
         IndexMappingTool.Factory.getInstance().init(client);
 
@@ -529,6 +530,7 @@ public Collection createComponents(
         toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance());
         toolFactories.put(SearchAlertsTool.TYPE, SearchAlertsTool.Factory.getInstance());
         toolFactories.put(IndexMappingTool.NAME, IndexMappingTool.Factory.getInstance());
+        toolFactories.put(RAGTool.TYPE, RAGTool.Factory.getInstance());
 
         if (externalToolFactories != null) {
             toolFactories.putAll(externalToolFactories);