From 09660b8ca49e5fc41a831c16325a2ed6848704e0 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sun, 30 Jun 2024 13:31:53 -0700 Subject: [PATCH] Add Search model group implementation Signed-off-by: Daniel Widdis --- .../SearchModelGroupTransportAction.java | 39 +++++- .../SearchModelGroupTransportActionTests.java | 123 +++++++++++++++++- 2 files changed, 151 insertions(+), 11 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 4fa95fafa6..6b4aefc862 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,8 +6,10 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; @@ -18,10 +20,14 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.SdkClientUtils; +import org.opensearch.sdk.SearchDataObjectRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -30,6 +36,7 @@ @Log4j2 public class SearchModelGroupTransportAction extends HandledTransportAction { Client client; + SdkClient sdkClient; ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; @@ -39,11 +46,13 @@ public SearchModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper ) { super(MLModelGroupSearchAction.NAME, transportService, actionFilters, SearchRequest::new); this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; } @@ -63,14 +72,36 @@ private void preProcessRoleAndPerformSearch(SearchRequest request, User user, Ac final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, doubleWrappedListener); - } else { + if (!modelAccessControlHelper.skipModelAccessControl(user)) { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by " + user.getBackendRoles()); - client.search(request, doubleWrappedListener); } + SearchDataObjectRequest searchDataObjecRequest = new SearchDataObjectRequest.Builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .build(); + sdkClient + .searchDataObjectAsync(searchDataObjecRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model group search complete: {}", searchResponse.getHits().getTotalHits()); + doubleWrappedListener.onResponse(searchResponse); + } catch (Exception e) { + log.error("Failed to parse search response", e); + doubleWrappedListener + .onFailure( + new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error(e.getMessage(), e); + doubleWrappedListener.onFailure(e); + } + }); } catch (Exception e) { log.error("Failed to search", e); listener.onFailure(e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index ecfb10f221..57a8f2afd0 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -5,33 +5,73 @@ package org.opensearch.ml.action.model_group; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; +import org.opensearch.sdk.SdkClient; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { + + private static TestThreadPool testThreadPool = new TestThreadPool( + SearchModelGroupTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock Client client; + SdkClient sdkClient; + @Mock NamedXContentRegistry namedXContentRegistry; @@ -41,9 +81,15 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + SearchResponse searchResponse; + + SearchSourceBuilder searchSourceBuilder; + + @Mock + FetchSourceContext fetchSourceContext; + @Mock ActionListener actionListener; @@ -61,10 +107,13 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); + + sdkClient = new LocalClusterIndicesClient(client, namedXContentRegistry); searchModelGroupTransportAction = new SearchModelGroupTransportAction( transportService, actionFilters, client, + sdkClient, clusterService, modelAccessControlHelper ); @@ -74,21 +123,81 @@ public void setup() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + when(fetchSourceContext.includes()).thenReturn(new String[] {}); + when(fetchSourceContext.excludes()).thenReturn(new String[] {}); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(Clusters.class), + null + ); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } - public void test_DoExecute() { + public void test_DoExecute() throws InterruptedException { + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); + verify(client).search(any()); + } + + public void test_DoExecute_Exception() throws InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("search failed")); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); - verify(client).search(any(), any()); + verify(client).search(any()); + verify(actionListener).onFailure(any(RuntimeException.class)); } - public void test_skipModelAccessControlTrue() { + public void test_skipModelAccessControlTrue() throws InterruptedException { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - verify(client).search(any(), any()); + verify(client).search(any()); } public void test_ThreadContextError() {