forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Neural sparse tool, do refactor using AbstractRetrievalTool (open…
…search-project#1686) * add abstract retriever class Signed-off-by: zhichao-aws <[email protected]> * extends the abstract class, add neural sparse tool Signed-off-by: zhichao-aws <[email protected]> * add register logic Signed-off-by: zhichao-aws <[email protected]> * tidy Signed-off-by: zhichao-aws <[email protected]> * add test class Signed-off-by: zhichao-aws <[email protected]> * add test,spotless Apply Signed-off-by: zhichao-aws <[email protected]> * fix wrong ut name Signed-off-by: zhichao-aws <[email protected]> * add description Signed-off-by: zhichao-aws <[email protected]> * tidy Signed-off-by: zhichao-aws <[email protected]> * add _index and _id to retriever tool result; modify ut Signed-off-by: zhichao-aws <[email protected]> * tidy Signed-off-by: zhichao-aws <[email protected]> * remove set description from tool factory Signed-off-by: zhichao-aws <[email protected]> --------- Signed-off-by: zhichao-aws <[email protected]>
- Loading branch information
1 parent
fee3c7b
commit e37f121
Showing
7 changed files
with
599 additions
and
0 deletions.
There are no files selected for viewing
154 changes: 154 additions & 0 deletions
154
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.io.IOException; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Abstract tool supports search paradigms in neural-search plugin. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
public abstract class AbstractRetrieverTool implements Tool { | ||
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; | ||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String SOURCE_FIELD = "source_field"; | ||
public static final String DOC_SIZE_FIELD = "doc_size"; | ||
|
||
protected String description = DEFAULT_DESCRIPTION; | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
protected String index; | ||
protected String[] sourceFields; | ||
protected Integer docSize; | ||
|
||
protected AbstractRetrieverTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String[] sourceFields, | ||
Integer docSize | ||
) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.index = index; | ||
this.sourceFields = sourceFields; | ||
this.docSize = docSize == null ? 2 : docSize; | ||
} | ||
|
||
protected abstract String getQueryBody(String queryText); | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
try { | ||
String question = parameters.get(INPUT_FIELD); | ||
try { | ||
question = gson.fromJson(question, String.class); | ||
} catch (Exception e) { | ||
// throw new IllegalArgumentException("wrong input"); | ||
} | ||
String query = getQueryBody(question); | ||
if (StringUtils.isBlank(query)) { | ||
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); | ||
} | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON | ||
.xContent() | ||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
searchSourceBuilder.fetchSource(sourceFields, null); | ||
searchSourceBuilder.size(docSize); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (int i = 0; i < hits.length; i++) { | ||
SearchHit hit = hits[i]; | ||
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> { | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
return gson.toJson(docContent); | ||
}); | ||
contextBuilder.append(doc).append("\n"); | ||
} | ||
listener.onResponse((T) gson.toJson(contextBuilder.toString())); | ||
} else { | ||
listener.onResponse((T) ""); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
}); | ||
client.search(searchRequest, actionListener); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
@Override | ||
public String getVersion() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
if (parameters == null || parameters.size() == 0) { | ||
return false; | ||
} | ||
String question = parameters.get("input"); | ||
return question != null; | ||
} | ||
|
||
public void setClient(Client client) { | ||
this.client = client; | ||
} | ||
|
||
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
125 changes: 125 additions & 0 deletions
125
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* This tool supports neural_sparse search with sparse encoding models and rank_features field. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
@ToolAnnotation(NeuralSparseTool.TYPE) | ||
public class NeuralSparseTool extends AbstractRetrieverTool { | ||
public static final String TYPE = "NeuralSparseTool"; | ||
public static final String MODEL_ID_FIELD = "model_id"; | ||
public static final String EMBEDDING_FIELD = "embedding_field"; | ||
private String name = TYPE; | ||
private String modelId; | ||
private String embeddingField; | ||
|
||
@Builder | ||
public NeuralSparseTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String embeddingField, | ||
String[] sourceFields, | ||
Integer k, | ||
Integer docSize, | ||
String modelId | ||
) { | ||
super(client, xContentRegistry, index, sourceFields, docSize); | ||
this.modelId = modelId; | ||
this.embeddingField = embeddingField; | ||
} | ||
|
||
@Override | ||
protected String getQueryBody(String queryText) { | ||
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { | ||
throw new IllegalArgumentException( | ||
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." | ||
); | ||
} | ||
return "{\"query\":{\"neural_sparse\":{\"" | ||
+ embeddingField | ||
+ "\":{\"query_text\":\"" | ||
+ queryText | ||
+ "\",\"model_id\":\"" | ||
+ modelId | ||
+ "\"}}}" | ||
+ " }"; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return this.name; | ||
} | ||
|
||
@Override | ||
public void setName(String s) { | ||
this.name = s; | ||
} | ||
|
||
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseTool> { | ||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (NeuralSparseTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
@Override | ||
public NeuralSparseTool create(Map<String, Object> params) { | ||
String index = (String) params.get(INDEX_FIELD); | ||
String embeddingField = (String) params.get(EMBEDDING_FIELD); | ||
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); | ||
String modelId = (String) params.get(MODEL_ID_FIELD); | ||
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; | ||
return NeuralSparseTool | ||
.builder() | ||
.client(client) | ||
.xContentRegistry(xContentRegistry) | ||
.index(index) | ||
.embeddingField(embeddingField) | ||
.sourceFields(sourceFields) | ||
.modelId(modelId) | ||
.docSize(docSize) | ||
.build(); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
Oops, something went wrong.