Skip to content

Commit

Permalink
Add Neural sparse tool, do refactor using AbstractRetrievalTool (open…
Browse files Browse the repository at this point in the history
…search-project#1686)

* add abstract retriever class

Signed-off-by: zhichao-aws <[email protected]>

* extends the abstract class, add neural sparse tool

Signed-off-by: zhichao-aws <[email protected]>

* add register logic

Signed-off-by: zhichao-aws <[email protected]>

* tidy

Signed-off-by: zhichao-aws <[email protected]>

* add test class

Signed-off-by: zhichao-aws <[email protected]>

* add test,spotless Apply

Signed-off-by: zhichao-aws <[email protected]>

* fix wrong ut name

Signed-off-by: zhichao-aws <[email protected]>

* add description

Signed-off-by: zhichao-aws <[email protected]>

* tidy

Signed-off-by: zhichao-aws <[email protected]>

* add _index and _id to retriever tool result; modify ut

Signed-off-by: zhichao-aws <[email protected]>

* tidy

Signed-off-by: zhichao-aws <[email protected]>

* remove set description from tool factory

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored and mingshl committed Dec 19, 2023
1 parent fee3c7b commit e37f121
Show file tree
Hide file tree
Showing 7 changed files with 599 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* Abstract tool supports search paradigms in neural-search plugin.
*/
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String SOURCE_FIELD = "source_field";
public static final String DOC_SIZE_FIELD = "doc_size";

protected String description = DEFAULT_DESCRIPTION;
protected Client client;
protected NamedXContentRegistry xContentRegistry;
protected String index;
protected String[] sourceFields;
protected Integer docSize;

protected AbstractRetrieverTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String[] sourceFields,
Integer docSize
) {
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.sourceFields = sourceFields;
this.docSize = docSize == null ? 2 : docSize;
}

protected abstract String getQueryBody(String queryText);

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String question = parameters.get(INPUT_FIELD);
try {
question = gson.fromJson(question, String.class);
} catch (Exception e) {
// throw new IllegalArgumentException("wrong input");
}
String query = getQueryBody(question);
if (StringUtils.isBlank(query)) {
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
}

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.fetchSource(sourceFields, null);
searchSourceBuilder.size(docSize);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
return gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
listener.onResponse((T) gson.toJson(contextBuilder.toString()));
} else {
listener.onResponse((T) "");
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});
client.search(searchRequest, actionListener);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
}
String question = parameters.get("input");
return question != null;
}

public void setClient(Client client) {
this.client = client;
}

protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> {
protected Client client;
protected NamedXContentRegistry xContentRegistry;

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
 * Retriever tool that issues a {@code neural_sparse} query, i.e. a search that uses a sparse
 * encoding model against a rank_features field.
 */
@Log4j2
@Getter
@Setter
@ToolAnnotation(NeuralSparseTool.TYPE)
public class NeuralSparseTool extends AbstractRetrieverTool {
    public static final String TYPE = "NeuralSparseTool";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String EMBEDDING_FIELD = "embedding_field";
    private String name = TYPE;
    private String modelId;
    private String embeddingField;

    /**
     * @param client           client used to execute the search
     * @param xContentRegistry registry used when parsing the query body
     * @param index            index to search
     * @param embeddingField   field the {@code neural_sparse} clause targets
     * @param sourceFields     {@code _source} fields to include in each hit
     * @param k                currently unused; kept so the generated builder API stays unchanged
     * @param docSize          number of hits to request
     * @param modelId          id of the sparse encoding model
     */
    @Builder
    public NeuralSparseTool(
        Client client,
        NamedXContentRegistry xContentRegistry,
        String index,
        String embeddingField,
        String[] sourceFields,
        Integer k,
        Integer docSize,
        String modelId
    ) {
        super(client, xContentRegistry, index, sourceFields, docSize);
        this.modelId = modelId;
        this.embeddingField = embeddingField;
    }

    /**
     * Builds the {@code neural_sparse} search body for the given text.
     *
     * <p>The field name, query text and model id are emitted as JSON string literals, so quotes,
     * backslashes or control characters in any of them cannot break out of their position in
     * the generated query.
     *
     * @param queryText text to search for
     * @return the JSON search body
     * @throws IllegalArgumentException if the embedding field or model id is blank
     */
    @Override
    protected String getQueryBody(String queryText) {
        if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
            throw new IllegalArgumentException(
                "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
            );
        }
        // String.valueOf keeps the previous rendering of a null text as the string "null".
        return "{\"query\":{\"neural_sparse\":{"
            + gson.toJson(embeddingField)
            + ":{\"query_text\":"
            + gson.toJson(String.valueOf(queryText))
            + ",\"model_id\":"
            + gson.toJson(modelId)
            + "}}}"
            + " }";
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Override
    public String getName() {
        return this.name;
    }

    @Override
    public void setName(String s) {
        this.name = s;
    }

    /**
     * Singleton factory that creates {@link NeuralSparseTool} instances from a parameter map.
     */
    public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseTool> {
        // volatile so a thread taking the unsynchronised fast path sees a fully built instance
        private static volatile Factory INSTANCE;

        /**
         * @return the shared factory, created on first use
         */
        public static Factory getInstance() {
            Factory existing = INSTANCE;
            if (existing != null) {
                return existing;
            }
            synchronized (NeuralSparseTool.class) {
                if (INSTANCE == null) {
                    INSTANCE = new Factory();
                }
                return INSTANCE;
            }
        }

        /**
         * Creates a tool from its configuration.
         *
         * @param params configuration keyed by {@code index}, {@code embedding_field},
         *               {@code source_field} (a JSON array of field names), {@code model_id}
         *               and the optional {@code doc_size} (a number or a numeric string)
         * @return the configured tool
         * @throws NumberFormatException if {@code doc_size} is present but not an integer
         */
        @Override
        public NeuralSparseTool create(Map<String, Object> params) {
            String index = (String) params.get(INDEX_FIELD);
            String embeddingField = (String) params.get(EMBEDDING_FIELD);
            String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
            String modelId = (String) params.get(MODEL_ID_FIELD);
            Integer docSize = parseDocSize(params.get(DOC_SIZE_FIELD));
            return NeuralSparseTool
                .builder()
                .client(client)
                .xContentRegistry(xContentRegistry)
                .index(index)
                .embeddingField(embeddingField)
                .sourceFields(sourceFields)
                .modelId(modelId)
                .docSize(docSize)
                .build();
        }

        /** Reads the doc size from a number or numeric string; absent means 2. */
        private static Integer parseDocSize(Object rawDocSize) {
            if (rawDocSize == null) {
                return 2;
            }
            if (rawDocSize instanceof Number) {
                return ((Number) rawDocSize).intValue();
            }
            return Integer.parseInt(rawDocSize.toString());
        }

        @Override
        public String getDefaultDescription() {
            return DEFAULT_DESCRIPTION;
        }
    }
}
Loading

0 comments on commit e37f121

Please sign in to comment.