diff --git a/common/src/main/java/org/opensearch/sdk/SdkClient.java b/common/src/main/java/org/opensearch/sdk/SdkClient.java index 13325ac54b..9fb195e13f 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClient.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClient.java @@ -11,15 +11,27 @@ import org.opensearch.OpenSearchException; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; public interface SdkClient { /** * Create/Put/Index a data object/document into a table/index. * @param request A request encapsulating the data object to store + * @param executor the executor to use for asynchronous execution * @return A completion stage encapsulating the response or exception */ - public CompletionStage putDataObjectAsync(PutDataObjectRequest request); + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor); + + /** + * Create/Put/Index a data object/document into a table/index. + * @param request A request encapsulating the data object to store + * @return A completion stage encapsulating the response or exception + */ + default CompletionStage putDataObjectAsync(PutDataObjectRequest request) { + return putDataObjectAsync(request, ForkJoinPool.commonPool()); + } /** * Create/Put/Index a data object/document into a table/index. @@ -41,9 +53,19 @@ default PutDataObjectResponse putDataObject(PutDataObjectRequest request) { /** * Read/Get a data object/document from a table/index. * @param request A request identifying the data object to retrieve + * @param executor the executor to use for asynchronous execution * @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception. */ - public CompletionStage getDataObjectAsync(GetDataObjectRequest request); + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor); + + /** + * Read/Get a data object/document from a table/index. + * @param request A request identifying the data object to retrieve + * @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception. + */ + default CompletionStage getDataObjectAsync(GetDataObjectRequest request){ + return getDataObjectAsync(request, ForkJoinPool.commonPool()); + } /** * Read/Get a data object/document from a table/index. @@ -62,12 +84,22 @@ default GetDataObjectResponse getDataObject(GetDataObjectRequest request) { } } + /** + * Delete a data object/document from a table/index. + * @param request A request identifying the data object to delete + * @param executor the executor to use for asynchronous execution + * @return A completion stage encapsulating the response or exception + */ + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor); + /** * Delete a data object/document from a table/index. * @param request A request identifying the data object to delete * @return A completion stage encapsulating the response or exception */ - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request); + default CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { + return deleteDataObjectAsync(request, ForkJoinPool.commonPool()); + } /** * Delete a data object/document from a table/index. diff --git a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java index 02f1435b75..08b3732a42 100644 --- a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java +++ b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java @@ -16,11 +16,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -50,17 +52,17 @@ public void setUp() { MockitoAnnotations.openMocks(this); sdkClient = spy(new SdkClient() { @Override - public CompletionStage putDataObjectAsync(PutDataObjectRequest request) { + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { return CompletableFuture.completedFuture(putResponse); } @Override - public CompletionStage getDataObjectAsync(GetDataObjectRequest request) { + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { return CompletableFuture.completedFuture(getResponse); } @Override - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { return CompletableFuture.completedFuture(deleteResponse); } }); @@ -71,12 +73,12 @@ public CompletionStage deleteDataObjectAsync(DeleteDat @Test public void testPutDataObjectSuccess() { assertEquals(putResponse, sdkClient.putDataObject(putRequest)); - verify(sdkClient).putDataObjectAsync(putRequest); + verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)); } @Test public void testPutDataObjectException() { - when(sdkClient.putDataObjectAsync(putRequest)) + when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { @@ -84,12 +86,12 @@ public void testPutDataObjectException() { }); assertEquals(testException, exception.getCause()); assertFalse(Thread.interrupted()); - verify(sdkClient).putDataObjectAsync(putRequest); + verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)); } @Test public void testPutDataObjectInterrupted() { - when(sdkClient.putDataObjectAsync(putRequest)) + when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(interruptedException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { @@ -97,18 +99,18 @@ public void testPutDataObjectInterrupted() { }); assertEquals(interruptedException, exception.getCause()); assertTrue(Thread.interrupted()); - verify(sdkClient).putDataObjectAsync(putRequest); + verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)); } @Test public void testGetDataObjectSuccess() { assertEquals(getResponse, sdkClient.getDataObject(getRequest)); - verify(sdkClient).getDataObjectAsync(getRequest); + verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } @Test public void testGetDataObjectException() { - when(sdkClient.getDataObjectAsync(getRequest)) + when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { @@ -116,12 +118,12 @@ public void testGetDataObjectException() { }); assertEquals(testException, exception.getCause()); assertFalse(Thread.interrupted()); - verify(sdkClient).getDataObjectAsync(getRequest); + verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } @Test public void testGetDataObjectInterrupted() { - when(sdkClient.getDataObjectAsync(getRequest)) + when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(interruptedException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { @@ -129,36 +131,36 @@ public void testGetDataObjectInterrupted() { }); assertEquals(interruptedException, exception.getCause()); assertTrue(Thread.interrupted()); - verify(sdkClient).getDataObjectAsync(getRequest); + verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } @Test public void testDeleteDataObjectSuccess() { assertEquals(deleteResponse, sdkClient.deleteDataObject(deleteRequest)); - verify(sdkClient).deleteDataObjectAsync(deleteRequest); + verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)); } @Test public void testDeleteDataObjectException() { - when(sdkClient.deleteDataObjectAsync(deleteRequest)) + when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.deleteDataObject(deleteRequest); }); assertEquals(testException, exception.getCause()); assertFalse(Thread.interrupted()); - verify(sdkClient).deleteDataObjectAsync(deleteRequest); + verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)); } @Test public void testDeleteDataObjectInterrupted() { - when(sdkClient.deleteDataObjectAsync(deleteRequest)) + when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(interruptedException)); OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.deleteDataObject(deleteRequest); }); assertEquals(interruptedException, exception.getCause()); assertTrue(Thread.interrupted()); - verify(sdkClient).deleteDataObjectAsync(deleteRequest); + verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index 6cb65279b1..73da453b5b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import java.util.ArrayList; import java.util.Arrays; @@ -108,7 +109,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener context.restore())); + }), context::restore)); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); @@ -125,8 +126,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { sdkClient - .deleteDataObjectAsync(new DeleteDataObjectRequest.Builder().index(deleteRequest.index()).id(deleteRequest.id()).build()) - .whenCompleteAsync((r, throwable) -> { + .deleteDataObjectAsync( + new DeleteDataObjectRequest.Builder().index(deleteRequest.index()).id(deleteRequest.id()).build(), + client.threadPool().executor(GENERAL_THREAD_POOL) + ) + .whenComplete((r, throwable) -> { context.restore(); if (throwable != null) { actionListener.onFailure(new RuntimeException(throwable)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index a5b71ce517..be1f477e34 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.opensearch.OpenSearchStatusException; @@ -72,51 +73,53 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - context.restore(); - log.debug("Completed Get Connector Request, id:{}", connectorId); - if (throwable != null) { - Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); - if (cause instanceof IndexNotFoundException) { - log.error("Failed to get connector index", cause); - actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + context.restore(); + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + if (cause instanceof IndexNotFoundException) { + log.error("Failed to get connector index", cause); + actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML connector " + connectorId, cause); + actionListener.onFailure(new RuntimeException(cause)); + } } else { - log.error("Failed to get ML connector " + connectorId, cause); - actionListener.onFailure(new RuntimeException(cause)); - } - } else { - if (r != null && r.parser().isPresent()) { - try { - XContentParser parser = r.parser().get(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector mlConnector = Connector.createConnector(parser); - mlConnector.removeCredential(); - if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { - actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this connector", - RestStatus.FORBIDDEN - ) - ); + if (r != null && r.parser().isPresent()) { + try { + XContentParser parser = r.parser().get(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { + actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to access this connector", + RestStatus.FORBIDDEN + ) + ); + } + } catch (Exception e) { + log.error("Failed to parse ml connector" + r.id(), e); + actionListener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to parse ml connector" + r.id(), e); - actionListener.onFailure(e); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: " + connectorId, + RestStatus.NOT_FOUND + ) + ); } - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "Failed to find connector with the provided connector id: " + connectorId, - RestStatus.NOT_FOUND - ) - ); } - } - }); + }); } catch (Exception e) { log.error("Failed to get ML connector " + connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 073a878036..0148fd2673 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.connector; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import java.util.HashSet; @@ -130,8 +131,11 @@ private void indexConnector(Connector connector, ActionListener { + .putDataObjectAsync( + new PutDataObjectRequest.Builder().index(ML_CONNECTOR_INDEX).dataObject(connector).build(), + client.threadPool().executor(GENERAL_THREAD_POOL) + ) + .whenComplete((r, throwable) -> { context.restore(); if (throwable != null) { listener.onFailure(new RuntimeException(throwable)); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 85ae6c8b86..6e3db16784 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -19,6 +19,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; @@ -65,7 +66,7 @@ public LocalClusterIndicesClient(Client client, NamedXContentRegistry xContentRe } @Override - public CompletionStage putDataObjectAsync(PutDataObjectRequest request) { + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { log.info("Indexing data object in {}", request.index()); @@ -81,11 +82,11 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe } catch (Exception e) { throw new OpenSearchException(e); } - })); + }), executor); } @Override - public CompletionStage getDataObjectAsync(GetDataObjectRequest request) { + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { log.info("Getting {} from {}", request.id(), request.index()); @@ -102,11 +103,11 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe } catch (Exception e) { throw new OpenSearchException(e); } - })); + }), executor); } @Override - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { log.info("Deleting {} from {}", request.id(), request.index()); @@ -121,6 +122,6 @@ public CompletionStage deleteDataObjectAsync(DeleteDat } catch (Exception e) { throw new OpenSearchException(e); } - })); + }), executor); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index f6012880e8..5f55a4471a 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -17,6 +17,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; import org.opensearch.OpenSearchException; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; @@ -61,7 +62,7 @@ public RemoteClusterIndicesClient(OpenSearchClient openSearchClient) { } @Override - public CompletionStage putDataObjectAsync(PutDataObjectRequest request) { + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { IndexRequest indexRequest = new IndexRequest.Builder<>().index(request.index()).document(request.dataObject()).build(); @@ -72,11 +73,11 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe } catch (Exception e) { throw new OpenSearchException("Error occurred while indexing data object", e); } - })); + }), executor); } @Override - public CompletionStage getDataObjectAsync(GetDataObjectRequest request) { + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { GetRequest getRequest = new GetRequest.Builder().index(request.index()).id(request.id()).build(); @@ -94,11 +95,11 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe } catch (Exception e) { throw new OpenSearchException(e); } - })); + }), executor); } @Override - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { DeleteRequest deleteRequest = new DeleteRequest.Builder().index(request.index()).id(request.id()).build(); @@ -118,6 +119,6 @@ public CompletionStage deleteDataObjectAsync(DeleteDat } catch (Exception e) { throw new OpenSearchException("Error occurred while deleting data object", e); } - })); + }), executor); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index e9cbc90592..7cd4b8ee3d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -13,11 +13,15 @@ import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.DELETED; import static org.opensearch.action.DocWriteResponse.Result.NOT_FOUND; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -36,6 +40,8 @@ import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -57,15 +63,25 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) // TODO: implement thread pool executors and remove this public class DeleteConnectorTransportActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportCreateConnectorActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock ThreadPool threadPool; @@ -128,6 +144,12 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testDeleteConnector_Success() throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index af98061883..ceed0dd400 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -6,14 +6,19 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -25,6 +30,8 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -41,15 +48,25 @@ import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) // TODO: implement thread pool executors and remove this public class GetConnectorTransportActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; + private static TestThreadPool testThreadPool = new TestThreadPool( + GetConnectorTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock ThreadPool threadPool; @@ -103,6 +120,12 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testGetConnector_UserHasNodeAccess() throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 638672dcff..c1cd9bcf83 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -5,12 +5,15 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; @@ -20,7 +23,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -34,6 +39,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; @@ -53,18 +60,29 @@ import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) // TODO: implement thread pool executors and remove this public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportCreateConnectorActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + private TransportCreateConnectorAction action; @Mock @@ -149,6 +167,7 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); List actions = new ArrayList<>(); actions @@ -175,6 +194,11 @@ public void setup() { when(request.getMlCreateConnectorInput()).thenReturn(input); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + public void test_execute_connectorAccessControl_notEnabled_success() throws InterruptedException { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); input.setAddAllBackendRoles(null); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index 85a1395b2c..e113c03c97 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -15,9 +15,15 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -34,6 +40,9 @@ import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -46,15 +55,26 @@ import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class LocalClusterIndicesClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; private static final String TEST_INDEX = "test_index"; + private static TestThreadPool testThreadPool = new TestThreadPool( + LocalClusterIndicesClientTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock private Client mockedClient; private SdkClient sdkClient; @@ -72,6 +92,11 @@ public void setup() { testDataObject = new TestDataObject("foo"); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + public void testPutDataObject() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); @@ -83,7 +108,10 @@ public void testPutDataObject() throws IOException { when(mockedClient.index(any(IndexRequest.class))).thenReturn(future); when(future.actionGet()).thenReturn(indexResponse); - PutDataObjectResponse response = sdkClient.putDataObject(putRequest); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); verify(mockedClient, times(1)).index(requestCaptor.capture()); @@ -101,8 +129,12 @@ public void testPutDataObject_Exception() throws IOException { return null; }).when(mockedClient).index(any(IndexRequest.class), any()); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } public void testGetDataObject() throws IOException { @@ -117,7 +149,10 @@ public void testGetDataObject() throws IOException { when(mockedClient.get(any(GetRequest.class))).thenReturn(future); when(future.actionGet()).thenReturn(getResponse); - GetDataObjectResponse response = sdkClient.getDataObject(getRequest); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); verify(mockedClient, times(1)).get(requestCaptor.capture()); @@ -139,7 +174,10 @@ public void testGetDataObject_NotFound() throws IOException { when(mockedClient.get(any(GetRequest.class))).thenReturn(future); when(future.actionGet()).thenReturn(getResponse); - GetDataObjectResponse response = sdkClient.getDataObject(getRequest); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); verify(mockedClient, times(1)).get(requestCaptor.capture()); @@ -157,8 +195,12 @@ public void testGetDataObject_Exception() throws IOException { return null; }).when(mockedClient).get(any(GetRequest.class), any()); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } public void testDeleteDataObject() throws IOException { @@ -172,7 +214,10 @@ public void testDeleteDataObject() throws IOException { future.onResponse(deleteResponse); when(mockedClient.delete(any(DeleteRequest.class))).thenReturn(future); - DeleteDataObjectResponse response = sdkClient.deleteDataObject(deleteRequest); + DeleteDataObjectResponse response = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); verify(mockedClient, times(1)).delete(requestCaptor.capture()); @@ -192,7 +237,11 @@ public void testDeleteDataObject_Exception() throws IOException { return null; }).when(mockedClient).delete(any(DeleteRequest.class), any()); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 2920225f32..239f3420eb 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -10,10 +10,16 @@ import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -28,6 +34,9 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -37,15 +46,26 @@ import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RemoteClusterIndicesClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; private static final String TEST_INDEX = "test_index"; + private static TestThreadPool testThreadPool = new TestThreadPool( + RemoteClusterIndicesClientTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock private OpenSearchClient mockedOpenSearchClient; private SdkClient sdkClient; @@ -60,6 +80,11 @@ public void setup() { testDataObject = new TestDataObject("foo"); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + public void testPutDataObject() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); @@ -76,21 +101,54 @@ public void testPutDataObject() throws IOException { ArgumentCaptor> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); when(mockedOpenSearchClient.index(indexRequestCaptor.capture())).thenReturn(indexResponse); - PutDataObjectResponse response = sdkClient.putDataObject(putRequest); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); assertEquals(TEST_INDEX, indexRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); assertTrue(response.created()); } + public void testPutDataObject_Updated() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); + + IndexResponse indexResponse = new IndexResponse.Builder() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Updated) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + ArgumentCaptor> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + when(mockedOpenSearchClient.index(indexRequestCaptor.capture())).thenReturn(indexResponse); + + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, indexRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertFalse(response.created()); + } + public void testPutDataObject_Exception() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); ArgumentCaptor> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); when(mockedOpenSearchClient.index(indexRequestCaptor.capture())).thenThrow(new IOException("test")); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -108,7 +166,10 @@ public void testGetDataObject() throws IOException { ArgumentCaptor> mapClassCaptor = ArgumentCaptor.forClass(Class.class); when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenReturn((GetResponse) getResponse); - GetDataObjectResponse response = sdkClient.getDataObject(getRequest); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); assertEquals(TEST_INDEX, getRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); @@ -128,7 +189,10 @@ public void testGetDataObject_NotFound() throws IOException { ArgumentCaptor> mapClassCaptor = ArgumentCaptor.forClass(Class.class); when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenReturn((GetResponse) getResponse); - GetDataObjectResponse response = sdkClient.getDataObject(getRequest); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); assertEquals(TEST_INDEX, getRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); @@ -143,8 +207,12 @@ public void testGetDataObject_Exception() throws IOException { ArgumentCaptor> mapClassCaptor = ArgumentCaptor.forClass(Class.class); when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenThrow(new IOException("test")); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } public void testDeleteDataObject() throws IOException { @@ -163,13 +231,46 @@ public void testDeleteDataObject() throws IOException { ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); when(mockedOpenSearchClient.delete(deleteRequestCaptor.capture())).thenReturn(deleteResponse); - DeleteDataObjectResponse response = sdkClient.deleteDataObject(deleteRequest); + DeleteDataObjectResponse response = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, deleteRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertEquals(2, response.shardInfo().getTotal()); + assertEquals(2, response.shardInfo().getSuccessful()); + assertEquals(0, response.shardInfo().getFailed()); + assertTrue(response.deleted()); + } + + public void testDeleteDataObject_NotFound() throws IOException { + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); + + DeleteResponse deleteResponse = new DeleteResponse.Builder() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.NotFound) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(2).total(2).build()) + .version(0) + .build(); + + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); + when(mockedOpenSearchClient.delete(deleteRequestCaptor.capture())).thenReturn(deleteResponse); + + DeleteDataObjectResponse response = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); assertEquals(TEST_INDEX, deleteRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); assertEquals(2, response.shardInfo().getTotal()); assertEquals(2, response.shardInfo().getSuccessful()); assertEquals(0, response.shardInfo().getFailed()); + assertFalse(response.deleted()); } public void testDeleteDataObject_Exception() throws IOException { @@ -178,7 +279,11 @@ public void testDeleteDataObject_Exception() throws IOException { ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); when(mockedOpenSearchClient.delete(deleteRequestCaptor.capture())).thenThrow(new IOException("test")); - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(OpenSearchException.class, ose.getCause().getClass()); + CompletableFuture future = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); } }