From 3e600a70bc467262fabe4692b3da7be98754d1c6 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 29 Jun 2024 20:03:09 -0700 Subject: [PATCH] Add Search model implementation Signed-off-by: Daniel Widdis --- .../ml/action/handler/MLSearchHandler.java | 76 ++++- .../models/SearchModelTransportAction.java | 12 +- .../SearchModelTransportActionTests.java | 303 +++++++++++------- 3 files changed, 273 insertions(+), 118 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 52be1dd608..724eeb2bb6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,6 +7,7 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -42,6 +43,9 @@ import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.SdkClientUtils; +import org.opensearch.sdk.SearchDataObjectRequest; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -76,10 +80,11 @@ public MLSearchHandler( /** * Fetch all the models from the model group index, and then create a combined query to model version index. + * @param sdkClient * @param request * @param actionListener */ - public void search(SearchRequest request, ActionListener actionListener) { + public void search(SdkClient sdkClient, SearchRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -126,10 +131,28 @@ public void search(SearchRequest request, ActionListener actionL request.source().fetchSource(rebuiltFetchSourceContext); final ActionListener doubleWrapperListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, doubleWrapperListener); - } else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { - client.search(request, doubleWrapperListener); + if (modelAccessControlHelper.skipModelAccessControl(user) + || !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .build(); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + doubleWrapperListener.onResponse(searchResponse); + } catch (Exception e) { + doubleWrapperListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + doubleWrapperListener.onFailure(e); + } + }); } else { SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); SearchRequest modelGroupSearchRequest = new SearchRequest(); @@ -148,17 +171,54 @@ public void search(SearchRequest request, ActionListener actionL Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); }); request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds)); - client.search(request, doubleWrapperListener); } else { log.debug("No model group found"); request.source().query(rewriteQueryBuilder(request.source().query(), null)); - client.search(request, doubleWrapperListener); } + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .build(); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((sr, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + doubleWrapperListener.onResponse(searchResponse); + } catch (Exception e) { + doubleWrapperListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + doubleWrapperListener.onFailure(e); + } + }); }, e -> { log.error("Fail to search model groups!", e); wrappedListener.onFailure(e); }); - client.search(modelGroupSearchRequest, modelGroupSearchActionListener); + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() + .indices(modelGroupSearchRequest.indices()) + .searchSourceBuilder(modelGroupSearchRequest.source()) + .build(); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + modelGroupSearchActionListener.onResponse(searchResponse); + } catch (Exception e) { + modelGroupSearchActionListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + modelGroupSearchActionListener.onFailure(e); + } + }); } } catch (Exception e) { log.error(e.getMessage(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java index 64d4913d5f..6402d3c44b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java @@ -14,6 +14,7 @@ import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -22,16 +23,23 @@ @Log4j2 public class SearchModelTransportAction extends HandledTransportAction { private final MLSearchHandler mlSearchHandler; + private final SdkClient sdkClient; @Inject - public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) { + public SearchModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + SdkClient sdkClient, + MLSearchHandler mlSearchHandler + ) { super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + this.sdkClient = sdkClient; this.mlSearchHandler = mlSearchHandler; } @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { request.indices(CommonValue.ML_MODEL_INDEX); - mlSearchHandler.search(request, actionListener); + mlSearchHandler.search(sdkClient, request, actionListener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 9687306241..e19d60bea5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -5,37 +5,46 @@ package org.opensearch.ml.action.models; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -44,20 +53,37 @@ import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class SearchModelTransportActionTests extends OpenSearchTestCase { + private static TestThreadPool testThreadPool = new TestThreadPool( + SearchModelTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); @Mock Client client; + SdkClient sdkClient; + @Mock NamedXContentRegistry namedXContentRegistry; @@ -67,9 +93,10 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + SearchResponse searchResponse; + @Mock ActionListener actionListener; @@ -97,163 +124,215 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); + + sdkClient = new LocalClusterIndicesClient(client, namedXContentRegistry); mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService)); - searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler); + searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, sdkClient, mlSearchHandler); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); - searchSourceBuilder.fetchSource(fetchSourceContext); - when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); Metadata metadata = mock(Metadata.class); when(metadata.hasIndex(anyString())).thenReturn(true); ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, Map.of(), 0, false); when(clusterService.state()).thenReturn(testState); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(Clusters.class), + null + ); } - public void test_DoExecute_admin() { + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + public void test_DoExecute_admin() throws InterruptedException { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(1)).search(any(), any()); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(1)).search(any()); } - public void test_DoExecute_addBackendRoles() throws IOException { + public void test_DoExecute_addBackendRoles() throws IOException, InterruptedException { SearchResponse searchResponse = createModelGroupSearchResponse(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(2)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(2)).search(any()); } - public void test_DoExecute_addBackendRoles_without_groupIds() { - SearchResponse searchResponse = mock(SearchResponse.class); - SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + public void test_DoExecute_addBackendRoles_without_groupIds() throws InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(2)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(2)).search(any()); } - public void test_DoExecute_addBackendRoles_exception() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("runtime exception")); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + public void test_DoExecute_addBackendRoles_exception() throws InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("runtime exception")); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(1)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(1)).search(any()); } - public void test_DoExecute_searchModel_before_model_creation_no_exception() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IndexNotFoundException("index not found exception")); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + public void test_DoExecute_searchModel_before_model_creation_no_exception() throws InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new IndexNotFoundException("index not found exception")); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(1)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); } - public void test_DoExecute_searchModel_before_model_creation_empty_search() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); - SearchResponseSections searchSections = new SearchResponseSections( - hits, - InternalAggregations.EMPTY, - null, - false, - false, - null, - 1 - ); - final SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 11, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + public void test_DoExecute_searchModel_before_model_creation_empty_search() throws InterruptedException { + SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(1)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(1)).search(any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); } - public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new MLResourceNotFoundException("ml resource not found exception")); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() throws InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new MLResourceNotFoundException("ml resource not found exception")); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(1)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(1)).search(any()); verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } - public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { + public void test_DoExecute_addBackendRoles_boolQuery() throws IOException, InterruptedException { SearchResponse searchResponse = createModelGroupSearchResponse(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(2)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(2)).search(any()); } - public void test_DoExecute_addBackendRoles_termQuery() throws IOException { + public void test_DoExecute_addBackendRoles_termQuery() throws IOException, InterruptedException { SearchResponse searchResponse = createModelGroupSearchResponse(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); - verify(client, times(2)).search(any(), any()); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + searchModelTransportAction.doExecute(null, searchRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + verify(mlSearchHandler).search(sdkClient, searchRequest, latchedActionListener); + verify(client, times(2)).search(any()); } private SearchResponse createModelGroupSearchResponse() throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" + " \"access\": \"public\",\n" @@ -264,7 +343,15 @@ private SearchResponse createModelGroupSearchResponse() throws IOException { + " }"; SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - return searchResponse; + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + hits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + return new SearchResponse(internalSearchResponse, null, 0, 0, 0, 1, ShardSearchFailure.EMPTY_ARRAY, mock(Clusters.class), null); } }