Skip to content

Commit

Permalink
Add Search model implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jun 30, 2024
1 parent ef29a3b commit 3e600a7
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
Expand Down Expand Up @@ -42,6 +43,9 @@
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.sdk.SdkClient;
import org.opensearch.sdk.SdkClientUtils;
import org.opensearch.sdk.SearchDataObjectRequest;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
Expand Down Expand Up @@ -76,10 +80,11 @@ public MLSearchHandler(

/**
* Fetch all the models from the model group index, and then create a combined query to model version index.
* @param sdkClient
* @param request
* @param actionListener
*/
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
public void search(SdkClient sdkClient, SearchRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -126,10 +131,28 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
request.source().fetchSource(rebuiltFetchSourceContext);
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, doubleWrapperListener);
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
client.search(request, doubleWrapperListener);
if (modelAccessControlHelper.skipModelAccessControl(user)
|| !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.build();
sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
doubleWrapperListener.onResponse(searchResponse);
} catch (Exception e) {
doubleWrapperListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable);
doubleWrapperListener.onFailure(e);
}
});
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchRequest modelGroupSearchRequest = new SearchRequest();
Expand All @@ -148,17 +171,54 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });

request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
client.search(request, doubleWrapperListener);
} else {
log.debug("No model group found");
request.source().query(rewriteQueryBuilder(request.source().query(), null));
client.search(request, doubleWrapperListener);
}
SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.build();
sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((sr, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
doubleWrapperListener.onResponse(searchResponse);
} catch (Exception e) {
doubleWrapperListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable);
doubleWrapperListener.onFailure(e);
}
});
}, e -> {
log.error("Fail to search model groups!", e);
wrappedListener.onFailure(e);
});
client.search(modelGroupSearchRequest, modelGroupSearchActionListener);
SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder()
.indices(modelGroupSearchRequest.indices())
.searchSourceBuilder(modelGroupSearchRequest.source())
.build();
sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
modelGroupSearchActionListener.onResponse(searchResponse);
} catch (Exception e) {
modelGroupSearchActionListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable);
modelGroupSearchActionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.sdk.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -22,16 +23,23 @@
@Log4j2
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private final MLSearchHandler mlSearchHandler;
private final SdkClient sdkClient;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
public SearchModelTransportAction(
TransportService transportService,
ActionFilters actionFilters,
SdkClient sdkClient,
MLSearchHandler mlSearchHandler
) {
super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.sdkClient = sdkClient;
this.mlSearchHandler = mlSearchHandler;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(CommonValue.ML_MODEL_INDEX);
mlSearchHandler.search(request, actionListener);
mlSearchHandler.search(sdkClient, request, actionListener);
}
}
Loading

0 comments on commit 3e600a7

Please sign in to comment.