-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Neural sparse tool, do refactor using AbstractRetrievalTool #1686
Merged
zane-neo
merged 15 commits into
opensearch-project:feature/agent_framework_dev
from
zhichao-aws:NeuralSparseTool
Dec 4, 2023
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
d462322
add abstract retriever class
zhichao-aws 52e8ea0
extends the abstract class, add neural sparse tool
zhichao-aws 0bcee44
Merge remote-tracking branch 'zhichao/feature/agent_framework_dev' in…
zhichao-aws d429420
add register logic
zhichao-aws eb0a429
merge latest
zhichao-aws 6f43ba6
tidy
zhichao-aws d9b6623
add test class
zhichao-aws 32d1170
add test,spotless Apply
zhichao-aws 741115c
fix wrong ut name
zhichao-aws e1282db
add description
zhichao-aws d345050
tidy
zhichao-aws dbc48fc
add _index and _id to retriever tool result; modify ut
zhichao-aws 12b9780
Merge branch 'feature/agent_framework_dev' into NeuralSparseTool
zhichao-aws ec598c6
tidy
zhichao-aws 6909355
remove set description from tool factory
zhichao-aws File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
154 changes: 154 additions & 0 deletions
154
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.io.IOException; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Abstract tool supports search paradigms in neural-search plugin. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
public abstract class AbstractRetrieverTool implements Tool { | ||
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; | ||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String SOURCE_FIELD = "source_field"; | ||
public static final String DOC_SIZE_FIELD = "doc_size"; | ||
|
||
protected String description = DEFAULT_DESCRIPTION; | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
protected String index; | ||
protected String[] sourceFields; | ||
protected Integer docSize; | ||
|
||
protected AbstractRetrieverTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String[] sourceFields, | ||
Integer docSize | ||
) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.index = index; | ||
this.sourceFields = sourceFields; | ||
this.docSize = docSize == null ? 2 : docSize; | ||
} | ||
|
||
protected abstract String getQueryBody(String queryText); | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
try { | ||
String question = parameters.get(INPUT_FIELD); | ||
try { | ||
question = gson.fromJson(question, String.class); | ||
} catch (Exception e) { | ||
// throw new IllegalArgumentException("wrong input"); | ||
} | ||
String query = getQueryBody(question); | ||
if (StringUtils.isBlank(query)) { | ||
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); | ||
} | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON | ||
.xContent() | ||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
searchSourceBuilder.fetchSource(sourceFields, null); | ||
searchSourceBuilder.size(docSize); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (int i = 0; i < hits.length; i++) { | ||
SearchHit hit = hits[i]; | ||
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> { | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
return gson.toJson(docContent); | ||
}); | ||
contextBuilder.append(doc).append("\n"); | ||
} | ||
listener.onResponse((T) gson.toJson(contextBuilder.toString())); | ||
} else { | ||
listener.onResponse((T) ""); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
}); | ||
client.search(searchRequest, actionListener); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
@Override | ||
public String getVersion() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
if (parameters == null || parameters.size() == 0) { | ||
return false; | ||
} | ||
String question = parameters.get("input"); | ||
return question != null; | ||
} | ||
|
||
public void setClient(Client client) { | ||
this.client = client; | ||
} | ||
|
||
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
125 changes: 125 additions & 0 deletions
125
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* This tool supports neural_sparse search with sparse encoding models and rank_features field. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
@ToolAnnotation(NeuralSparseTool.TYPE) | ||
public class NeuralSparseTool extends AbstractRetrieverTool { | ||
public static final String TYPE = "NeuralSparseTool"; | ||
public static final String MODEL_ID_FIELD = "model_id"; | ||
public static final String EMBEDDING_FIELD = "embedding_field"; | ||
private String name = TYPE; | ||
private String modelId; | ||
private String embeddingField; | ||
|
||
@Builder | ||
public NeuralSparseTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String embeddingField, | ||
String[] sourceFields, | ||
Integer k, | ||
Integer docSize, | ||
String modelId | ||
) { | ||
super(client, xContentRegistry, index, sourceFields, docSize); | ||
this.modelId = modelId; | ||
this.embeddingField = embeddingField; | ||
} | ||
|
||
@Override | ||
protected String getQueryBody(String queryText) { | ||
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { | ||
throw new IllegalArgumentException( | ||
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." | ||
); | ||
} | ||
return "{\"query\":{\"neural_sparse\":{\"" | ||
+ embeddingField | ||
+ "\":{\"query_text\":\"" | ||
+ queryText | ||
+ "\",\"model_id\":\"" | ||
+ modelId | ||
+ "\"}}}" | ||
+ " }"; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return this.name; | ||
} | ||
|
||
@Override | ||
public void setName(String s) { | ||
this.name = s; | ||
} | ||
|
||
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseTool> { | ||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (NeuralSparseTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
@Override | ||
public NeuralSparseTool create(Map<String, Object> params) { | ||
String index = (String) params.get(INDEX_FIELD); | ||
String embeddingField = (String) params.get(EMBEDDING_FIELD); | ||
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); | ||
String modelId = (String) params.get(MODEL_ID_FIELD); | ||
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; | ||
return NeuralSparseTool | ||
.builder() | ||
.client(client) | ||
.xContentRegistry(xContentRegistry) | ||
.index(index) | ||
.embeddingField(embeddingField) | ||
.sourceFields(sourceFields) | ||
.modelId(modelId) | ||
.docSize(docSize) | ||
.build(); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing the refactoring.