From d462322beca60cd0f16cb1d8256e85249b8681d5 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 20 Nov 2023 18:00:47 +0800 Subject: [PATCH 01/12] add abstract retriever class Signed-off-by: zhichao-aws --- .../engine/tools/AbstractRetrieverTool.java | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java new file mode 100644 index 0000000000..35ba8ad1ff --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -0,0 +1,138 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +/** + * Abstract tool supports search paradigms in neural-search plugin. + */ +@Log4j2 +public abstract class AbstractRetrieverTool implements Tool { + protected static String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; + @Getter @Setter + protected String description = DEFAULT_DESCRIPTION; + + protected Client client; + protected final NamedXContentRegistry xContentRegistry; + protected final String index; + protected final String embeddingField; + protected final String[] sourceFields; + protected final String modelId; + protected final Integer docSize; + + protected AbstractRetrieverTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { + this.client = client; + this.xContentRegistry = xContentRegistry; + this.index = index; + this.embeddingField = embeddingField; + this.sourceFields = sourceFields; + this.modelId = modelId; + this.docSize = docSize == null? 2 : docSize; + } + + protected abstract String getQueryBody(String queryText); + + @Override + public void run(Map parameters, ActionListener listener) { + try { + String question = parameters.get("input"); + try { + question = gson.fromJson(question, String.class); + } catch (Exception e) { + //throw new IllegalArgumentException("wrong input"); + } + String query = getQueryBody(question); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.fetchSource(sourceFields, null); + searchSourceBuilder.size(docSize); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); + ActionListener actionListener = ActionListener.wrap(r -> { + SearchHit[] hits = r.getHits().getHits(); + + if (hits != null && hits.length > 0) { + StringBuilder contextBuilder = new StringBuilder(); + for (int i = 0; i < hits.length; i++) { + SearchHit hit = hits[i]; + String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map docContent = new HashMap<>(); + docContent.put("_id", hit.getId()); + docContent.put("_source", hit.getSourceAsMap()); + return gson.toJson(docContent); + }); + contextBuilder.append(doc).append("\n"); + } + listener.onResponse((T)gson.toJson(contextBuilder.toString())); + } else { + listener.onResponse((T)""); + } + }, e -> { + log.error("Failed to search index", e); + listener.onFailure(e); + }); + client.search(searchRequest, actionListener); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + String question = parameters.get("input"); + return question != null; + } + + public void setClient(Client client) { + this.client = client; + } + + public static abstract class Factory implements Tool.Factory { + protected Client client; + protected NamedXContentRegistry xContentRegistry; + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} From 52e8ea06dd182da7264ef2a03d92f46adb409779 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 21 Nov 2023 10:42:55 +0800 Subject: [PATCH 02/12] extends the abstract class, add neural sparse tool Signed-off-by: zhichao-aws --- .../engine/tools/AbstractRetrieverTool.java | 16 +-- .../ml/engine/tools/NeuralSparseTool.java | 93 +++++++++++++++ .../ml/engine/tools/VectorDBTool.java | 109 +----------------- 3 files changed, 107 insertions(+), 111 deletions(-) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index 35ba8ad1ff..2c4e5e52d5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -38,14 +38,14 @@ public abstract class AbstractRetrieverTool implements Tool { protected String description = DEFAULT_DESCRIPTION; protected Client client; - protected final NamedXContentRegistry xContentRegistry; - protected final String index; - protected final String embeddingField; - protected final String[] sourceFields; - protected final String modelId; - protected final Integer docSize; - - protected AbstractRetrieverTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { + protected NamedXContentRegistry xContentRegistry; + protected String index; + protected String embeddingField; + protected String[] sourceFields; + protected String modelId; + protected Integer docSize; + + protected AbstractRetrieverTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer docSize, String modelId) { this.client = client; this.xContentRegistry = xContentRegistry; this.index = index; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java new file mode 100644 index 0000000000..810ff7a390 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +/** + * This tool supports neural_sparse search with sparse encoding models and rank_features field. + */@Log4j2 +@ToolAnnotation(NeuralSparseTool.TYPE) +public class NeuralSparseTool extends AbstractRetrieverTool { + public static final String TYPE = "NeuralSparseTool"; + @Setter @Getter + private String name = TYPE; + + @Builder + public NeuralSparseTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { + super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); + } + + @Override + protected String getQueryBody(String queryText){ + return "{\"query\":{\"neural_sparse\":{\"" + embeddingField + "\":{\"query_text\":\"" + queryText + "\",\"model_id\":\"" + + modelId + "\"}}}" + " }"; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + public static class Factory extends AbstractRetrieverTool.Factory { + private static Factory INSTANCE; + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (NeuralSparseTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + @Override + public NeuralSparseTool create(Map params) { + String index = (String)params.get("index"); + String embeddingField = (String)params.get("embedding_field"); + String[] sourceFields = gson.fromJson((String)params.get("source_field"), String[].class); + String modelId = (String)params.get("model_id"); + Integer docSize = params.containsKey("doc_size")? Integer.parseInt((String)params.get("doc_size")) : 2; + return NeuralSparseTool.builder() + .client(client) + .xContentRegistry(xContentRegistry) + .index(index) + .embeddingField(embeddingField) + .sourceFields(sourceFields) + .modelId(modelId) + .docSize(docSize) + .build(); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index 94239c2c81..df72547362 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -9,23 +9,10 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; import java.util.Map; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -35,80 +22,22 @@ */ @Log4j2 @ToolAnnotation(VectorDBTool.TYPE) -public class VectorDBTool implements Tool { +public class VectorDBTool extends AbstractRetrieverTool { public static final String TYPE = "VectorDBTool"; @Setter @Getter private String name = TYPE; - private static String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; - @Getter @Setter - private String description = DEFAULT_DESCRIPTION; - - private Client client; - private NamedXContentRegistry xContentRegistry; - private String index; - private String embeddingField; - private String[] sourceFields; - private String modelId; - private Integer docSize ; private Integer k; @Builder public VectorDBTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { - this.client = client; - this.xContentRegistry = xContentRegistry; - this.index = index; - this.embeddingField = embeddingField; - this.sourceFields = sourceFields; - this.modelId = modelId; - this.docSize = docSize == null? 2 : docSize; + super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); this.k = k == null? 10 : k; } @Override - public void run(Map parameters, ActionListener listener) { - try { - String question = parameters.get("input"); - try { - question = gson.fromJson(question, String.class); - } catch (Exception e) { - //throw new IllegalArgumentException("wrong input"); - } - String query = "{\"query\":{\"neural\":{\"" + embeddingField + "\":{\"query_text\":\"" + question + "\",\"model_id\":\"" - + modelId + "\",\"k\":" + k + "}}}" + " }"; - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); - searchSourceBuilder.parseXContent(queryParser); - searchSourceBuilder.fetchSource(sourceFields, null); - searchSourceBuilder.size(docSize); - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); - ActionListener actionListener = ActionListener.wrap(r -> { - SearchHit[] hits = r.getHits().getHits(); - - if (hits != null && hits.length > 0) { - StringBuilder contextBuilder = new StringBuilder(); - for (int i = 0; i < hits.length; i++) { - SearchHit hit = hits[i]; - String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - Map docContent = new HashMap<>(); - docContent.put("_id", hit.getId()); - docContent.put("_source", hit.getSourceAsMap()); - return gson.toJson(docContent); - }); - contextBuilder.append(doc).append("\n"); - } - listener.onResponse((T)gson.toJson(contextBuilder.toString())); - } else { - listener.onResponse((T)""); - } - }, e -> { - log.error("Failed to search index", e); - listener.onFailure(e); - }); - client.search(searchRequest, actionListener); - } catch (IOException e) { - throw new RuntimeException(e); - } + protected String getQueryBody(String queryText){ + return "{\"query\":{\"neural\":{\"" + embeddingField + "\":{\"query_text\":\"" + queryText + "\",\"model_id\":\"" + + modelId + "\",\"k\":" + k + "}}}" + " }"; } @Override @@ -116,11 +45,6 @@ public String getType() { return TYPE; } - @Override - public String getVersion() { - return null; - } - @Override public String getName() { return this.name; @@ -131,23 +55,7 @@ public void setName(String s) { this.name = s; } - @Override - public boolean validate(Map parameters) { - if (parameters == null || parameters.size() == 0) { - return false; - } - String question = parameters.get("input"); - return question != null; - } - - public void setClient(Client client) { - this.client = client; - } - - public static class Factory implements Tool.Factory { - private Client client; - private NamedXContentRegistry xContentRegistry; - + public static class Factory extends AbstractRetrieverTool.Factory { private static Factory INSTANCE; public static Factory getInstance() { if (INSTANCE != null) { @@ -162,11 +70,6 @@ public static Factory getInstance() { } } - public void init(Client client, NamedXContentRegistry xContentRegistry) { - this.client = client; - this.xContentRegistry = xContentRegistry; - } - @Override public VectorDBTool create(Map params) { String index = (String)params.get("index"); From d4294208738fd7090deb57adc52bb0b21f8154f6 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 21 Nov 2023 10:01:22 +0000 Subject: [PATCH 03/12] add register logic Signed-off-by: zhichao-aws --- .../opensearch/ml/engine/tools/NeuralSparseTool.java | 3 ++- .../opensearch/ml/plugin/MachineLearningPlugin.java | 10 +++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 810ff7a390..67ed7d23b7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -19,7 +19,8 @@ /** * This tool supports neural_sparse search with sparse encoding models and rank_features field. - */@Log4j2 + */ +@Log4j2 @ToolAnnotation(NeuralSparseTool.TYPE) public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index d6c90cdb8d..40e087d8ce 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -143,13 +143,7 @@ import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; -import org.opensearch.ml.engine.tools.AgentTool; -import org.opensearch.ml.engine.tools.CatIndexTool; -import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.engine.tools.MathTool; -import org.opensearch.ml.engine.tools.PainlessScriptTool; -import org.opensearch.ml.engine.tools.VectorDBTool; -import org.opensearch.ml.engine.tools.VisualizationsTool; +import org.opensearch.ml.engine.tools.*; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -497,6 +491,7 @@ public Collection createComponents( MLModelTool.Factory.getInstance().init(client); MathTool.Factory.getInstance().init(scriptService); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); + NeuralSparseTool.Factory.getInstance().init(client, xContentRegistry); AgentTool.Factory.getInstance().init(client); CatIndexTool.Factory.getInstance().init(client, clusterService); PainlessScriptTool.Factory.getInstance().init(client, scriptService); @@ -504,6 +499,7 @@ public Collection createComponents( toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(MathTool.TYPE, MathTool.Factory.getInstance()); toolFactories.put(VectorDBTool.TYPE, VectorDBTool.Factory.getInstance()); + toolFactories.put(NeuralSparseTool.TYPE, NeuralSparseTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); toolFactories.put(CatIndexTool.TYPE, CatIndexTool.Factory.getInstance()); toolFactories.put(PainlessScriptTool.TYPE, PainlessScriptTool.Factory.getInstance()); From 6f43ba6204e43196917359de3f61533f8d6ca230 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 23 Nov 2023 03:04:24 +0000 Subject: [PATCH 04/12] tidy Signed-off-by: zhichao-aws --- .../engine/tools/AbstractRetrieverTool.java | 46 +++++++----- .../ml/engine/tools/NeuralSparseTool.java | 71 ++++++++++++------- .../ml/engine/tools/VectorDBTool.java | 27 +++++-- 3 files changed, 96 insertions(+), 48 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index 2c4e5e52d5..d29677a45a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -5,9 +5,14 @@ package org.opensearch.ml.engine.tools; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; @@ -20,13 +25,9 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; -import java.util.Map; - -import static org.opensearch.ml.common.utils.StringUtils.gson; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; /** * Abstract tool supports search paradigms in neural-search plugin. @@ -34,7 +35,8 @@ @Log4j2 public abstract class AbstractRetrieverTool implements Tool { protected static String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; - @Getter @Setter + @Getter + @Setter protected String description = DEFAULT_DESCRIPTION; protected Client client; @@ -45,14 +47,22 @@ public abstract class AbstractRetrieverTool implements Tool { protected String modelId; protected Integer docSize; - protected AbstractRetrieverTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer docSize, String modelId) { + protected AbstractRetrieverTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer docSize, + String modelId + ) { this.client = client; this.xContentRegistry = xContentRegistry; this.index = index; this.embeddingField = embeddingField; this.sourceFields = sourceFields; this.modelId = modelId; - this.docSize = docSize == null? 2 : docSize; + this.docSize = docSize == null ? 2 : docSize; } protected abstract String getQueryBody(String queryText); @@ -64,12 +74,14 @@ public void run(Map parameters, ActionListener listener) try { question = gson.fromJson(question, String.class); } catch (Exception e) { - //throw new IllegalArgumentException("wrong input"); + // throw new IllegalArgumentException("wrong input"); } String query = getQueryBody(question); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); + XContentParser queryParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); searchSourceBuilder.parseXContent(queryParser); searchSourceBuilder.fetchSource(sourceFields, null); searchSourceBuilder.size(docSize); @@ -89,9 +101,9 @@ public void run(Map parameters, ActionListener listener) }); contextBuilder.append(doc).append("\n"); } - listener.onResponse((T)gson.toJson(contextBuilder.toString())); + listener.onResponse((T) gson.toJson(contextBuilder.toString())); } else { - listener.onResponse((T)""); + listener.onResponse((T) ""); } }, e -> { log.error("Failed to search index", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 67ed7d23b7..8e8ee43a71 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -5,17 +5,18 @@ package org.opensearch.ml.engine.tools; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Map; + import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import java.util.Map; - -import static org.opensearch.ml.common.utils.StringUtils.gson; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; /** * This tool supports neural_sparse search with sparse encoding models and rank_features field. @@ -24,18 +25,34 @@ @ToolAnnotation(NeuralSparseTool.TYPE) public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; - @Setter @Getter + @Setter + @Getter private String name = TYPE; @Builder - public NeuralSparseTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { + public NeuralSparseTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer k, + Integer docSize, + String modelId + ) { super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); } @Override - protected String getQueryBody(String queryText){ - return "{\"query\":{\"neural_sparse\":{\"" + embeddingField + "\":{\"query_text\":\"" + queryText + "\",\"model_id\":\"" - + modelId + "\"}}}" + " }"; + protected String getQueryBody(String queryText) { + return "{\"query\":{\"neural_sparse\":{\"" + + embeddingField + + "\":{\"query_text\":\"" + + queryText + + "\",\"model_id\":\"" + + modelId + + "\"}}}" + + " }"; } @Override @@ -55,6 +72,7 @@ public void setName(String s) { public static class Factory extends AbstractRetrieverTool.Factory { private static Factory INSTANCE; + public static Factory getInstance() { if (INSTANCE != null) { return INSTANCE; @@ -70,20 +88,21 @@ public static Factory getInstance() { @Override public NeuralSparseTool create(Map params) { - String index = (String)params.get("index"); - String embeddingField = (String)params.get("embedding_field"); - String[] sourceFields = gson.fromJson((String)params.get("source_field"), String[].class); - String modelId = (String)params.get("model_id"); - Integer docSize = params.containsKey("doc_size")? Integer.parseInt((String)params.get("doc_size")) : 2; - return NeuralSparseTool.builder() - .client(client) - .xContentRegistry(xContentRegistry) - .index(index) - .embeddingField(embeddingField) - .sourceFields(sourceFields) - .modelId(modelId) - .docSize(docSize) - .build(); + String index = (String) params.get("index"); + String embeddingField = (String) params.get("embedding_field"); + String[] sourceFields = gson.fromJson((String) params.get("source_field"), String[].class); + String modelId = (String) params.get("model_id"); + Integer docSize = params.containsKey("doc_size") ? Integer.parseInt((String) params.get("doc_size")) : 2; + return NeuralSparseTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .index(index) + .embeddingField(embeddingField) + .sourceFields(sourceFields) + .modelId(modelId) + .docSize(docSize) + .build(); } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index af1a91a94e..bdb3b6cf89 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -31,15 +31,32 @@ public class VectorDBTool extends AbstractRetrieverTool { private Integer k; @Builder - public VectorDBTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) { + public VectorDBTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer k, + Integer docSize, + String modelId + ) { super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); - this.k = k == null? 10 : k; + this.k = k == null ? 10 : k; } @Override - protected String getQueryBody(String queryText){ - return "{\"query\":{\"neural\":{\"" + embeddingField + "\":{\"query_text\":\"" + queryText + "\",\"model_id\":\"" - + modelId + "\",\"k\":" + k + "}}}" + " }"; + protected String getQueryBody(String queryText) { + return "{\"query\":{\"neural\":{\"" + + embeddingField + + "\":{\"query_text\":\"" + + queryText + + "\",\"model_id\":\"" + + modelId + + "\",\"k\":" + + k + + "}}}" + + " }"; } @Override From d9b6623d6b6fc03d53a1fb5d45b25c91b0ba266c Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 23 Nov 2023 04:07:31 +0000 Subject: [PATCH 05/12] add test class Signed-off-by: zhichao-aws --- .../engine/tools/AbstractRetrieverTool.java | 2 +- .../tools/AbstractRetrieverToolTests.java | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index d29677a45a..b38f4f2123 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -133,7 +133,7 @@ public void setClient(Client client) { this.client = client; } - public static abstract class Factory implements Tool.Factory { + protected static abstract class Factory implements Tool.Factory { protected Client client; protected NamedXContentRegistry xContentRegistry; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java new file mode 100644 index 0000000000..0e4b1e3d5d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -0,0 +1,20 @@ +package org.opensearch.ml.engine.tools; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.Map; + +public class AbstractRetrieverToolTests { + static protected String MOCKED_QUERY = "mock query"; + + @Test + public void testDemo() throws Exception { + AbstractRetrieverTool mockedImpl = Mockito.mock(AbstractRetrieverTool.class, Mockito.CALLS_REAL_METHODS); + when(mockedImpl.getQueryBody(any(String.class))).thenReturn(MOCKED_QUERY); + mockedImpl.run(Map.of()); + } +} From 32d117008d78f9218f3358374133feec70b208a2 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 24 Nov 2023 14:31:49 +0800 Subject: [PATCH 06/12] add test,spotless Apply Signed-off-by: zhichao-aws --- .../engine/tools/AbstractRetrieverTool.java | 26 ++-- .../ml/engine/tools/IndexMappingTool.java | 33 ++-- .../ml/engine/tools/NeuralSparseTool.java | 28 +++- .../ml/engine/tools/VectorDBTool.java | 28 +++- .../tools/AbstractRetrieverToolTests.java | 143 +++++++++++++++++- .../engine/tools/IndexMappingToolTests.java | 39 +++-- .../engine/tools/NeuralSparseToolTests.java | 75 +++++++++ .../retrieval_tool_empty_search_response.json | 18 +++ .../tools/retrieval_tool_search_response.json | 35 +++++ 9 files changed, 355 insertions(+), 70 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java create mode 100644 ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_empty_search_response.json create mode 100644 ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index b38f4f2123..d376ae53a7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -13,6 +13,7 @@ import java.util.HashMap; import java.util.Map; +import org.apache.commons.lang3.StringUtils; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; @@ -33,35 +34,33 @@ * Abstract tool supports search paradigms in neural-search plugin. */ @Log4j2 +@Getter +@Setter public abstract class AbstractRetrieverTool implements Tool { - protected static String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; - @Getter - @Setter - protected String description = DEFAULT_DESCRIPTION; + public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String SOURCE_FIELD = "source_field"; + public static final String DOC_SIZE_FIELD = "doc_size"; + protected String description = DEFAULT_DESCRIPTION; protected Client client; protected NamedXContentRegistry xContentRegistry; protected String index; - protected String embeddingField; protected String[] sourceFields; - protected String modelId; protected Integer docSize; protected AbstractRetrieverTool( Client client, NamedXContentRegistry xContentRegistry, String index, - String embeddingField, String[] sourceFields, - Integer docSize, - String modelId + Integer docSize ) { this.client = client; this.xContentRegistry = xContentRegistry; this.index = index; - this.embeddingField = embeddingField; this.sourceFields = sourceFields; - this.modelId = modelId; this.docSize = docSize == null ? 2 : docSize; } @@ -70,13 +69,16 @@ protected AbstractRetrieverTool( @Override public void run(Map parameters, ActionListener listener) { try { - String question = parameters.get("input"); + String question = parameters.get(INPUT_FIELD); try { question = gson.fromJson(question, String.class); } catch (Exception e) { // throw new IllegalArgumentException("wrong input"); } String query = getQueryBody(question); + if (StringUtils.isBlank(query)) { + throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); + } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); XContentParser queryParser = XContentType.JSON diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java index d36e590cb4..99a7955cd0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java @@ -5,8 +5,14 @@ package org.opensearch.ml.engine.tools; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + import org.apache.logging.log4j.util.Strings; import org.opensearch.action.admin.indices.get.GetIndexRequest; import org.opensearch.action.admin.indices.get.GetIndexResponse; @@ -20,13 +26,9 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import lombok.Getter; +import lombok.Setter; @ToolAnnotation(IndexMappingTool.NAME) public class IndexMappingTool implements Tool { @@ -74,7 +76,7 @@ public void run(Map parameters, ActionListener listener) listener.onResponse(empty); return; } - + final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); @@ -96,16 +98,16 @@ public void onResponse(GetIndexResponse getIndexResponse) { StringBuilder sb = new StringBuilder(); for (String index : getIndexResponse.indices()) { sb.append("index: ").append(index).append("\n\n"); - + MappingMetadata mapping = getIndexResponse.mappings().get(index); if (mapping != null) { sb.append("mappings:\n"); - for (Entry entry: mapping.sourceAsMap().entrySet()) { - sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n'); - } + for (Entry entry : mapping.sourceAsMap().entrySet()) { + sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n'); + } sb.append("\n\n"); } - + Settings settings = getIndexResponse.settings().get(index); if (settings != null) { sb.append("settings:\n").append(settings.toDelimitedString('\n')).append("\n\n"); @@ -126,7 +128,8 @@ public void onFailure(final Exception e) { } }; - final GetIndexRequest getIndexRequest = new GetIndexRequest().indices(indices) + final GetIndexRequest getIndexRequest = new GetIndexRequest() + .indices(indices) .indicesOptions(indicesOptions) .local(local) .clusterManagerNodeTimeout(clusterManagerNodeTimeout); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 8e8ee43a71..33a3f45fb1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -9,6 +9,7 @@ import java.util.Map; +import org.apache.commons.lang3.StringUtils; import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; @@ -22,12 +23,16 @@ * This tool supports neural_sparse search with sparse encoding models and rank_features field. */ @Log4j2 +@Getter +@Setter @ToolAnnotation(NeuralSparseTool.TYPE) public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; - @Setter - @Getter + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding_field"; private String name = TYPE; + private String modelId; + private String embeddingField; @Builder public NeuralSparseTool( @@ -40,11 +45,18 @@ public NeuralSparseTool( Integer docSize, String modelId ) { - super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); + super(client, xContentRegistry, index, sourceFields, docSize); + this.modelId = modelId; + this.embeddingField = embeddingField; } @Override protected String getQueryBody(String queryText) { + if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { + throw new IllegalArgumentException( + "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." + ); + } return "{\"query\":{\"neural_sparse\":{\"" + embeddingField + "\":{\"query_text\":\"" @@ -88,11 +100,11 @@ public static Factory getInstance() { @Override public NeuralSparseTool create(Map params) { - String index = (String) params.get("index"); - String embeddingField = (String) params.get("embedding_field"); - String[] sourceFields = gson.fromJson((String) params.get("source_field"), String[].class); - String modelId = (String) params.get("model_id"); - Integer docSize = params.containsKey("doc_size") ? Integer.parseInt((String) params.get("doc_size")) : 2; + String index = (String) params.get(INDEX_FIELD); + String embeddingField = (String) params.get(EMBEDDING_FIELD); + String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); + String modelId = (String) params.get(MODEL_ID_FIELD); + Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; return NeuralSparseTool .builder() .client(client) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index bdb3b6cf89..468bf8f7d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -9,6 +9,7 @@ import java.util.Map; +import org.apache.commons.lang3.StringUtils; import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; @@ -22,13 +23,17 @@ * This tool supports neural search with embedding models and knn index. */ @Log4j2 +@Getter +@Setter @ToolAnnotation(VectorDBTool.TYPE) public class VectorDBTool extends AbstractRetrieverTool { public static final String TYPE = "VectorDBTool"; - @Setter - @Getter + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding_field"; private String name = TYPE; private Integer k; + private String modelId; + private String embeddingField; @Builder public VectorDBTool( @@ -41,12 +46,19 @@ public VectorDBTool( Integer docSize, String modelId ) { - super(client, xContentRegistry, index, embeddingField, sourceFields, docSize, modelId); + super(client, xContentRegistry, index, sourceFields, docSize); + this.modelId = modelId; + this.embeddingField = embeddingField; this.k = k == null ? 10 : k; } @Override protected String getQueryBody(String queryText) { + if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { + throw new IllegalArgumentException( + "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." + ); + } return "{\"query\":{\"neural\":{\"" + embeddingField + "\":{\"query_text\":\"" @@ -92,11 +104,11 @@ public static Factory getInstance() { @Override public VectorDBTool create(Map params) { - String index = (String) params.get("index"); - String embeddingField = (String) params.get("embedding_field"); - String[] sourceFields = gson.fromJson((String) params.get("source_field"), String[].class); - String modelId = (String) params.get("model_id"); - Integer docSize = params.containsKey("doc_size") ? Integer.parseInt((String) params.get("doc_size")) : 2; + String index = (String) params.get(INDEX_FIELD); + String embeddingField = (String) params.get(EMBEDDING_FIELD); + String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); + String modelId = (String) params.get(MODEL_ID_FIELD); + Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; return VectorDBTool .builder() .client(client) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java index 0e4b1e3d5d..3a807e0e81 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -1,20 +1,151 @@ package org.opensearch.ml.engine.tools; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import java.io.InputStream; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.SearchModule; -import java.util.Map; +import lombok.SneakyThrows; public class AbstractRetrieverToolTests { - static protected String MOCKED_QUERY = "mock query"; + static public final String TEST_QUERY = "{\"query\":{\"match_all\":{}}}"; + static public final String TEST_INDEX = "test index"; + static public final String[] TEST_SOURCE_FIELDS = new String[] { "test 1", "test 2" }; + static public final Integer TEST_DOC_SIZE = 3; + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private String mockedSearchResponseString; + private String mockedEmptySearchResponseString; + private AbstractRetrieverTool mockedImpl; + + @Before + @SneakyThrows + public void setup() { + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes()); + } + } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_empty_search_response.json")) { + if (searchResponseIns != null) { + mockedEmptySearchResponseString = new String(searchResponseIns.readAllBytes()); + } + } + + mockedImpl = Mockito + .mock( + AbstractRetrieverTool.class, + Mockito + .withSettings() + .useConstructor(null, TEST_XCONTENT_REGISTRY_FOR_QUERY, TEST_INDEX, TEST_SOURCE_FIELDS, TEST_DOC_SIZE) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + when(mockedImpl.getQueryBody(any(String.class))).thenReturn(TEST_QUERY); + } + + @Test + @SneakyThrows + public void testRunAsyncWithSearchResults() { + Client client = mock(Client.class); + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals( + "{\"_source\":{\"passage_text\":\"Company xyz have a history of 100 years.\"},\"_id\":\"1\"}\n" + + "{\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\"}\n", + gson.fromJson(future.get(), String.class) + ); + } @Test - public void testDemo() throws Exception { - AbstractRetrieverTool mockedImpl = Mockito.mock(AbstractRetrieverTool.class, Mockito.CALLS_REAL_METHODS); - when(mockedImpl.getQueryBody(any(String.class))).thenReturn(MOCKED_QUERY); - mockedImpl.run(Map.of()); + @SneakyThrows + public void testRunAsyncWithEmptyInputQuestionThenThrowException() { + Client client = mock(Client.class); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals("", future.get()); + } + + @Test + @SneakyThrows + public void testRunAsyncWithIllegalQueryThenThrowException() { + Client client = mock(Client.class); + mockedImpl.setClient(client); + + assertThrows( + "[input] is null or empty, can not process it.", + IllegalArgumentException.class, + () -> mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""), null) + ); + + assertThrows( + "[input] is null or empty, can not process it.", + IllegalArgumentException.class, + () -> mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "), null) + ); + + assertThrows( + "[input] is null or empty, can not process it.", + IllegalArgumentException.class, + () -> mockedImpl.run(Map.of("test", "hello world"), null) + ); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java index 1c6c441b1c..9585b0b953 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java @@ -5,14 +5,12 @@ package org.opensearch.ml.engine.tools; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; -import org.opensearch.ml.common.spi.tools.Tool; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.when; import java.util.Arrays; import java.util.Collections; @@ -21,6 +19,11 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -29,13 +32,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.json.JsonXContent; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.when; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.spi.tools.Tool; public class IndexMappingToolTests { @@ -68,7 +67,6 @@ public void setup() { emptyParams = Collections.emptyMap(); } - @Test public void testRunAsyncNoIndexParams() throws Exception { Tool tool = IndexMappingTool.Factory.getInstance().create(Collections.emptyMap()); @@ -80,7 +78,7 @@ public void testRunAsyncNoIndexParams() throws Exception { future.join(); assertEquals("There were no results searching the index parameter [null].", future.get()); } - + @Test public void testRunAsyncNoIndices() throws Exception { Tool tool = IndexMappingTool.Factory.getInstance().create(Collections.emptyMap()); @@ -92,7 +90,7 @@ public void testRunAsyncNoIndices() throws Exception { future.join(); assertEquals("There were no results searching the index parameter [null].", future.get()); } - + @Test public void testRunAsyncNoResults() throws Exception { @SuppressWarnings("unchecked") @@ -164,9 +162,8 @@ public void testRunAsyncIndexMapping() throws Exception { assertTrue(responseList.contains("mappings:")); assertTrue( - responseList.contains( - "mappings={year={full_name=year, mapping={year={type=text}}}, age={full_name=age, mapping={age={type=integer}}}}" - ) + responseList + .contains("mappings={year={full_name=year, mapping={year={type=text}}}, age={full_name=age, mapping={age={type=integer}}}}") ); assertTrue(responseList.contains("settings:")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java new file mode 100644 index 0000000000..fef948bc0c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import lombok.SneakyThrows; + +public class NeuralSparseToolTests { + public static final String TEST_EMBEDDING_FIELD = "test embedding"; + public static final String TEST_MODEL_ID = "123fsd23134"; + private Map params = new HashMap<>(); + + @Before + public void setup() { + params.put(NeuralSparseTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); + params.put(NeuralSparseTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(NeuralSparseTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); + params.put(NeuralSparseTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(NeuralSparseTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + } + + @Test + @SneakyThrows + public void testCreateTool() { + NeuralSparseTool tool = NeuralSparseTool.Factory.getInstance().create(params); + assertEquals(AbstractRetrieverToolTests.TEST_INDEX, tool.getIndex()); + assertEquals(TEST_EMBEDDING_FIELD, tool.getEmbeddingField()); + assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); + assertEquals(TEST_MODEL_ID, tool.getModelId()); + assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals("NeuralSparseTool", tool.getType()); + assertEquals("NeuralSparseTool", tool.getName()); + assertEquals("Use this tool to search data in OpenSearch index.", NeuralSparseTool.Factory.getInstance().getDefaultDescription()); + } + + @Test + @SneakyThrows + public void testGetQueryBody() { + NeuralSparseTool tool = NeuralSparseTool.Factory.getInstance().create(params); + assertEquals( + "{\"query\":{\"neural_sparse\":{\"test embedding\":{\"" + + "query_text\":\"{\"query\":{\"match_all\":{}}}\",\"model_id\":\"123fsd23134\"}}} }", + tool.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithIllegalParams() { + Map illegalParams1 = new HashMap<>(params); + illegalParams1.remove(NeuralSparseTool.MODEL_ID_FIELD); + NeuralSparseTool tool1 = NeuralSparseTool.Factory.getInstance().create(illegalParams1); + assertThrows( + "Parameter [embedding_field] and [model_id] can not be null or empty.", + IllegalArgumentException.class, + () -> tool1.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + + Map illegalParams2 = new HashMap<>(params); + illegalParams1.remove(NeuralSparseTool.EMBEDDING_FIELD); + NeuralSparseTool tool2 = NeuralSparseTool.Factory.getInstance().create(illegalParams1); + assertThrows( + "Parameter [embedding_field] and [model_id] can not be null or empty.", + IllegalArgumentException.class, + () -> tool2.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + } +} diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_empty_search_response.json b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_empty_search_response.json new file mode 100644 index 0000000000..7ca6bfa76f --- /dev/null +++ b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_empty_search_response.json @@ -0,0 +1,18 @@ +{ + "took": 4, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 0, + "relation": "eq" + }, + "max_score": null, + "hits": [] + } +} \ No newline at end of file diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json new file mode 100644 index 0000000000..b4226589dd --- /dev/null +++ b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json @@ -0,0 +1,35 @@ +{ + "took": 201, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 89.2917, + "hits": [ + { + "_index": "hybrid-index", + "_id": "1", + "_score": 89.2917, + "_source": { + "passage_text": "Company xyz have a history of 100 years." + } + }, + { + "_index": "hybrid-index", + "_id": "2", + "_score": 0.10702579, + "_source": { + "passage_text": "the price of the api is 2$ per invokation" + } + } + ] + } +} \ No newline at end of file From 741115c85a885a1ea341975adb73d2d7901dffca Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 27 Nov 2023 16:07:47 +0800 Subject: [PATCH 07/12] fix wrong ut name Signed-off-by: zhichao-aws --- .../opensearch/ml/engine/tools/AbstractRetrieverToolTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java index 3a807e0e81..afef220874 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -99,7 +99,7 @@ public void testRunAsyncWithSearchResults() { @Test @SneakyThrows - public void testRunAsyncWithEmptyInputQuestionThenThrowException() { + public void testRunAsyncWithEmptySearchResponse() { Client client = mock(Client.class); SearchResponse mockedEmptySearchResponse = SearchResponse .fromXContent( From e1282db1dc8dc5c1dea6e9e18f7af18645765514 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 29 Nov 2023 17:20:23 +0800 Subject: [PATCH 08/12] add description Signed-off-by: zhichao-aws --- .../org/opensearch/ml/engine/tools/NeuralSparseTool.java | 7 ++++++- .../java/org/opensearch/ml/engine/tools/VectorDBTool.java | 8 +++++++- .../opensearch/ml/engine/tools/NeuralSparseToolTests.java | 3 +++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 33a3f45fb1..2f5c01b12d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -30,6 +30,7 @@ public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String DESCRIPTION_FIELD = "description"; private String name = TYPE; private String modelId; private String embeddingField; @@ -105,7 +106,7 @@ public NeuralSparseTool create(Map params) { String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); String modelId = (String) params.get(MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; - return NeuralSparseTool + NeuralSparseTool neuralSparseTool = NeuralSparseTool .builder() .client(client) .xContentRegistry(xContentRegistry) @@ -115,6 +116,10 @@ public NeuralSparseTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); + if(params.containsKey(DESCRIPTION_FIELD)){ + neuralSparseTool.setDescription((String) params.get(DESCRIPTION_FIELD)); + } + return neuralSparseTool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index 468bf8f7d1..cea8bd9568 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.Map; +import java.util.Vector; import org.apache.commons.lang3.StringUtils; import org.opensearch.client.Client; @@ -30,6 +31,7 @@ public class VectorDBTool extends AbstractRetrieverTool { public static final String TYPE = "VectorDBTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String DESCRIPTION_FIELD = "description"; private String name = TYPE; private Integer k; private String modelId; @@ -109,7 +111,7 @@ public VectorDBTool create(Map params) { String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); String modelId = (String) params.get(MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; - return VectorDBTool + VectorDBTool vectorDBTool = VectorDBTool .builder() .client(client) .xContentRegistry(xContentRegistry) @@ -119,6 +121,10 @@ public VectorDBTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); + if(params.containsKey(DESCRIPTION_FIELD)){ + vectorDBTool.setDescription((String) params.get(DESCRIPTION_FIELD)); + } + return vectorDBTool; } @Override diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java index fef948bc0c..7254ea712b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java @@ -15,6 +15,7 @@ public class NeuralSparseToolTests { public static final String TEST_EMBEDDING_FIELD = "test embedding"; public static final String TEST_MODEL_ID = "123fsd23134"; + public static final String TEST_DESCRIPTION = "test"; private Map params = new HashMap<>(); @Before @@ -24,6 +25,7 @@ public void setup() { params.put(NeuralSparseTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); params.put(NeuralSparseTool.MODEL_ID_FIELD, TEST_MODEL_ID); params.put(NeuralSparseTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + params.put(NeuralSparseTool.DESCRIPTION_FIELD, TEST_DESCRIPTION); } @Test @@ -35,6 +37,7 @@ public void testCreateTool() { assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); assertEquals(TEST_MODEL_ID, tool.getModelId()); assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals(TEST_DESCRIPTION, tool.getDescription()); assertEquals("NeuralSparseTool", tool.getType()); assertEquals("NeuralSparseTool", tool.getName()); assertEquals("Use this tool to search data in OpenSearch index.", NeuralSparseTool.Factory.getInstance().getDefaultDescription()); From d34505046dd0fa8efa1552da71675761d1788671 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 29 Nov 2023 17:30:49 +0800 Subject: [PATCH 09/12] tidy Signed-off-by: zhichao-aws --- .../java/org/opensearch/ml/engine/tools/NeuralSparseTool.java | 2 +- .../main/java/org/opensearch/ml/engine/tools/VectorDBTool.java | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 2f5c01b12d..b282bfe0b1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -116,7 +116,7 @@ public NeuralSparseTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); - if(params.containsKey(DESCRIPTION_FIELD)){ + if (params.containsKey(DESCRIPTION_FIELD)) { neuralSparseTool.setDescription((String) params.get(DESCRIPTION_FIELD)); } return neuralSparseTool; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index cea8bd9568..c6942ad8fb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -8,7 +8,6 @@ import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.Map; -import java.util.Vector; import org.apache.commons.lang3.StringUtils; import org.opensearch.client.Client; @@ -121,7 +120,7 @@ public VectorDBTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); - if(params.containsKey(DESCRIPTION_FIELD)){ + if (params.containsKey(DESCRIPTION_FIELD)) { vectorDBTool.setDescription((String) params.get(DESCRIPTION_FIELD)); } return vectorDBTool; From dbc48fc98cd9ba5a5845cf49c6cf69efb00788e7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 30 Nov 2023 12:58:08 +0800 Subject: [PATCH 10/12] add _index and _id to retriever tool result; modify ut Signed-off-by: zhichao-aws --- .../org/opensearch/ml/engine/tools/AbstractRetrieverTool.java | 2 ++ .../ml/engine/tools/AbstractRetrieverToolTests.java | 4 ++-- .../ml/engine/tools/retrieval_tool_search_response.json | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index d376ae53a7..587dfeb7f9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -97,7 +97,9 @@ public void run(Map parameters, ActionListener listener) SearchHit hit = hits[i]; String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { Map docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); docContent.put("_source", hit.getSourceAsMap()); return gson.toJson(docContent); }); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java index afef220874..174a502a82 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -91,8 +91,8 @@ public void testRunAsyncWithSearchResults() { future.join(); assertEquals( - "{\"_source\":{\"passage_text\":\"Company xyz have a history of 100 years.\"},\"_id\":\"1\"}\n" - + "{\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\"}\n", + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" + + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", gson.fromJson(future.get(), String.class) ); } diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json index b4226589dd..7e66dd60e8 100644 --- a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json +++ b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json @@ -19,7 +19,7 @@ "_id": "1", "_score": 89.2917, "_source": { - "passage_text": "Company xyz have a history of 100 years." + "passage_text": "Company test_mock have a history of 100 years." } }, { From ec598c6a9c892b9656b6a930e4e01b603ae5f449 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 30 Nov 2023 13:07:33 +0800 Subject: [PATCH 11/12] tidy Signed-off-by: zhichao-aws --- .../ml/engine/tools/AbstractRetrieverToolTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java index 174a502a82..f5251498da 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -91,8 +91,8 @@ public void testRunAsyncWithSearchResults() { future.join(); assertEquals( - "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" + - "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" + + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", gson.fromJson(future.get(), String.class) ); } From 6909355b715007a49a4c3df7b29077604cd9b273 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 30 Nov 2023 15:04:14 +0800 Subject: [PATCH 12/12] remove set description from tool factory Signed-off-by: zhichao-aws --- .../org/opensearch/ml/engine/tools/NeuralSparseTool.java | 7 +------ .../java/org/opensearch/ml/engine/tools/VectorDBTool.java | 7 +------ .../opensearch/ml/engine/tools/NeuralSparseToolTests.java | 3 --- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index b282bfe0b1..33a3f45fb1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -30,7 +30,6 @@ public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; - public static final String DESCRIPTION_FIELD = "description"; private String name = TYPE; private String modelId; private String embeddingField; @@ -106,7 +105,7 @@ public NeuralSparseTool create(Map params) { String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); String modelId = (String) params.get(MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; - NeuralSparseTool neuralSparseTool = NeuralSparseTool + return NeuralSparseTool .builder() .client(client) .xContentRegistry(xContentRegistry) @@ -116,10 +115,6 @@ public NeuralSparseTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); - if (params.containsKey(DESCRIPTION_FIELD)) { - neuralSparseTool.setDescription((String) params.get(DESCRIPTION_FIELD)); - } - return neuralSparseTool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index c6942ad8fb..468bf8f7d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -30,7 +30,6 @@ public class VectorDBTool extends AbstractRetrieverTool { public static final String TYPE = "VectorDBTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; - public static final String DESCRIPTION_FIELD = "description"; private String name = TYPE; private Integer k; private String modelId; @@ -110,7 +109,7 @@ public VectorDBTool create(Map params) { String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); String modelId = (String) params.get(MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; - VectorDBTool vectorDBTool = VectorDBTool + return VectorDBTool .builder() .client(client) .xContentRegistry(xContentRegistry) @@ -120,10 +119,6 @@ public VectorDBTool create(Map params) { .modelId(modelId) .docSize(docSize) .build(); - if (params.containsKey(DESCRIPTION_FIELD)) { - vectorDBTool.setDescription((String) params.get(DESCRIPTION_FIELD)); - } - return vectorDBTool; } @Override diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java index 7254ea712b..fef948bc0c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/NeuralSparseToolTests.java @@ -15,7 +15,6 @@ public class NeuralSparseToolTests { public static final String TEST_EMBEDDING_FIELD = "test embedding"; public static final String TEST_MODEL_ID = "123fsd23134"; - public static final String TEST_DESCRIPTION = "test"; private Map params = new HashMap<>(); @Before @@ -25,7 +24,6 @@ public void setup() { params.put(NeuralSparseTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); params.put(NeuralSparseTool.MODEL_ID_FIELD, TEST_MODEL_ID); params.put(NeuralSparseTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); - params.put(NeuralSparseTool.DESCRIPTION_FIELD, TEST_DESCRIPTION); } @Test @@ -37,7 +35,6 @@ public void testCreateTool() { assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); assertEquals(TEST_MODEL_ID, tool.getModelId()); assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); - assertEquals(TEST_DESCRIPTION, tool.getDescription()); assertEquals("NeuralSparseTool", tool.getType()); assertEquals("NeuralSparseTool", tool.getName()); assertEquals("Use this tool to search data in OpenSearch index.", NeuralSparseTool.Factory.getInstance().getDefaultDescription());