Skip to content

Commit

Permalink
Allow specifying executors for async threads
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed May 20, 2024
1 parent cf2f7a6 commit 18841df
Show file tree
Hide file tree
Showing 12 changed files with 383 additions and 113 deletions.
38 changes: 35 additions & 3 deletions common/src/main/java/org/opensearch/sdk/SdkClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,27 @@
import org.opensearch.OpenSearchException;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;

public interface SdkClient {

/**
* Create/Put/Index a data object/document into a table/index.
* @param request A request encapsulating the data object to store
* @param executor the executor to use for asynchronous execution
* @return A completion stage encapsulating the response or exception
*/
public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRequest request);
public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRequest request, Executor executor);

/**
* Create/Put/Index a data object/document into a table/index.
* @param request A request encapsulating the data object to store
* @return A completion stage encapsulating the response or exception
*/
default CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRequest request) {
return putDataObjectAsync(request, ForkJoinPool.commonPool());
}

/**
* Create/Put/Index a data object/document into a table/index.
Expand All @@ -41,9 +53,19 @@ default PutDataObjectResponse putDataObject(PutDataObjectRequest request) {
/**
* Read/Get a data object/document from a table/index.
* @param request A request identifying the data object to retrieve
* @param executor the executor to use for asynchronous execution
* @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception.
*/
public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRequest request);
public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRequest request, Executor executor);

/**
* Read/Get a data object/document from a table/index.
* @param request A request identifying the data object to retrieve
* @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception.
*/
default CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRequest request){
return getDataObjectAsync(request, ForkJoinPool.commonPool());
}

/**
* Read/Get a data object/document from a table/index.
Expand All @@ -62,12 +84,22 @@ default GetDataObjectResponse getDataObject(GetDataObjectRequest request) {
}
}

/**
* Delete a data object/document from a table/index.
* @param request A request identifying the data object to delete
* @param executor the executor to use for asynchronous execution
* @return A completion stage encapsulating the response or exception
*/
public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor);

/**
* Delete a data object/document from a table/index.
* @param request A request identifying the data object to delete
* @return A completion stage encapsulating the response or exception
*/
public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDataObjectRequest request);
default CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDataObjectRequest request) {
return deleteDataObjectAsync(request, ForkJoinPool.commonPool());
}

/**
* Delete a data object/document from a table/index.
Expand Down
38 changes: 20 additions & 18 deletions common/src/test/java/org/opensearch/sdk/SdkClientTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -50,17 +52,17 @@ public void setUp() {
MockitoAnnotations.openMocks(this);
sdkClient = spy(new SdkClient() {
@Override
public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRequest request) {
public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRequest request, Executor executor) {
return CompletableFuture.completedFuture(putResponse);
}

@Override
public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRequest request) {
public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRequest request, Executor executor) {
return CompletableFuture.completedFuture(getResponse);
}

@Override
public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDataObjectRequest request) {
public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) {
return CompletableFuture.completedFuture(deleteResponse);
}
});
Expand All @@ -71,94 +73,94 @@ public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDat
@Test
public void testPutDataObjectSuccess() {
assertEquals(putResponse, sdkClient.putDataObject(putRequest));
verify(sdkClient).putDataObjectAsync(putRequest);
verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class));
}

@Test
public void testPutDataObjectException() {
when(sdkClient.putDataObjectAsync(putRequest))
when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(testException));

OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.putDataObject(putRequest);
});
assertEquals(testException, exception.getCause());
assertFalse(Thread.interrupted());
verify(sdkClient).putDataObjectAsync(putRequest);
verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class));
}

@Test
public void testPutDataObjectInterrupted() {
when(sdkClient.putDataObjectAsync(putRequest))
when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(interruptedException));

OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.putDataObject(putRequest);
});
assertEquals(interruptedException, exception.getCause());
assertTrue(Thread.interrupted());
verify(sdkClient).putDataObjectAsync(putRequest);
verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class));
}

@Test
public void testGetDataObjectSuccess() {
assertEquals(getResponse, sdkClient.getDataObject(getRequest));
verify(sdkClient).getDataObjectAsync(getRequest);
verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class));
}

@Test
public void testGetDataObjectException() {
when(sdkClient.getDataObjectAsync(getRequest))
when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(testException));

OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.getDataObject(getRequest);
});
assertEquals(testException, exception.getCause());
assertFalse(Thread.interrupted());
verify(sdkClient).getDataObjectAsync(getRequest);
verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class));
}

@Test
public void testGetDataObjectInterrupted() {
when(sdkClient.getDataObjectAsync(getRequest))
when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(interruptedException));

OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.getDataObject(getRequest);
});
assertEquals(interruptedException, exception.getCause());
assertTrue(Thread.interrupted());
verify(sdkClient).getDataObjectAsync(getRequest);
verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class));
}

@Test
public void testDeleteDataObjectSuccess() {
assertEquals(deleteResponse, sdkClient.deleteDataObject(deleteRequest));
verify(sdkClient).deleteDataObjectAsync(deleteRequest);
verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class));
}

@Test
public void testDeleteDataObjectException() {
when(sdkClient.deleteDataObjectAsync(deleteRequest))
when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(testException));
OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.deleteDataObject(deleteRequest);
});
assertEquals(testException, exception.getCause());
assertFalse(Thread.interrupted());
verify(sdkClient).deleteDataObjectAsync(deleteRequest);
verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class));
}

@Test
public void testDeleteDataObjectInterrupted() {
when(sdkClient.deleteDataObjectAsync(deleteRequest))
when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)))
.thenReturn(CompletableFuture.failedFuture(interruptedException));
OpenSearchException exception = assertThrows(OpenSearchException.class, () -> {
sdkClient.deleteDataObject(deleteRequest);
});
assertEquals(interruptedException, exception.getCause());
assertTrue(Thread.interrupted());
verify(sdkClient).deleteDataObjectAsync(deleteRequest);
verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -108,7 +109,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}), () -> context.restore()));
}), context::restore));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
Expand All @@ -125,8 +126,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener<DeleteResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
sdkClient
.deleteDataObjectAsync(new DeleteDataObjectRequest.Builder().index(deleteRequest.index()).id(deleteRequest.id()).build())
.whenCompleteAsync((r, throwable) -> {
.deleteDataObjectAsync(
new DeleteDataObjectRequest.Builder().index(deleteRequest.index()).id(deleteRequest.id()).build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((r, throwable) -> {
context.restore();
if (throwable != null) {
actionListener.onFailure(new RuntimeException(throwable));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
Expand Down Expand Up @@ -72,51 +73,53 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConn
.build();
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
sdkClient.getDataObjectAsync(getDataObjectRequest).whenCompleteAsync((r, throwable) -> {
context.restore();
log.debug("Completed Get Connector Request, id:{}", connectorId);
if (throwable != null) {
Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause();
if (cause instanceof IndexNotFoundException) {
log.error("Failed to get connector index", cause);
actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND));
sdkClient
.getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
context.restore();
log.debug("Completed Get Connector Request, id:{}", connectorId);
if (throwable != null) {
Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause();
if (cause instanceof IndexNotFoundException) {
log.error("Failed to get connector index", cause);
actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND));
} else {
log.error("Failed to get ML connector " + connectorId, cause);
actionListener.onFailure(new RuntimeException(cause));
}
} else {
log.error("Failed to get ML connector " + connectorId, cause);
actionListener.onFailure(new RuntimeException(cause));
}
} else {
if (r != null && r.parser().isPresent()) {
try {
XContentParser parser = r.parser().get();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Connector mlConnector = Connector.createConnector(parser);
mlConnector.removeCredential();
if (connectorAccessControlHelper.hasPermission(user, mlConnector)) {
actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build());
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"You don't have permission to access this connector",
RestStatus.FORBIDDEN
)
);
if (r != null && r.parser().isPresent()) {
try {
XContentParser parser = r.parser().get();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Connector mlConnector = Connector.createConnector(parser);
mlConnector.removeCredential();
if (connectorAccessControlHelper.hasPermission(user, mlConnector)) {
actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build());
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"You don't have permission to access this connector",
RestStatus.FORBIDDEN
)
);
}
} catch (Exception e) {
log.error("Failed to parse ml connector" + r.id(), e);
actionListener.onFailure(e);
}
} catch (Exception e) {
log.error("Failed to parse ml connector" + r.id(), e);
actionListener.onFailure(e);
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to find connector with the provided connector id: " + connectorId,
RestStatus.NOT_FOUND
)
);
}
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to find connector with the provided connector id: " + connectorId,
RestStatus.NOT_FOUND
)
);
}
}
});
});
} catch (Exception e) {
log.error("Failed to get ML connector " + connectorId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.util.HashSet;
Expand Down Expand Up @@ -130,8 +131,11 @@ private void indexConnector(Connector connector, ActionListener<MLCreateConnecto
}
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
sdkClient
.putDataObjectAsync(new PutDataObjectRequest.Builder().index(ML_CONNECTOR_INDEX).dataObject(connector).build())
.whenCompleteAsync((r, throwable) -> {
.putDataObjectAsync(
new PutDataObjectRequest.Builder().index(ML_CONNECTOR_INDEX).dataObject(connector).build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((r, throwable) -> {
context.restore();
if (throwable != null) {
listener.onFailure(new RuntimeException(throwable));
Expand Down
Loading

0 comments on commit 18841df

Please sign in to comment.