diff --git a/plugin/build.gradle b/plugin/build.gradle index e12d0f2cd3..6011471c95 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -66,30 +66,10 @@ dependencies { implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' implementation "org.opensearch.client:opensearch-java:2.10.2" - implementation 'software.amazon.awssdk:apache-client:2.25.50' - implementation 'software.amazon.awssdk:http-client-spi:2.25.50' - implementation 'software.amazon.awssdk:regions:2.25.50' - implementation 'software.amazon.awssdk:utils:2.25.50' - + checkstyle "com.puppycrawl.tools:checkstyle:${project.checkstyle.toolVersion}" configurations.all { - resolutionStrategy.force 'software.amazon.awssdk:apache-client:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:http-client-spi:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:regions:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:utils:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:metrics-spi:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:annotations:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:sdk-core:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:profiles:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:json-utils:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:http-auth-spi:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:http-auth-aws:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:identity-spi:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:checksums-spi:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:checksums:2.25.50' - resolutionStrategy.force 'software.amazon.awssdk:third-party-jackson-core:2.25.50' - resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4' resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5-h2:5.2.4' resolutionStrategy.force 'jakarta.json:jakarta.json-api:2.1.3' diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index e279f559e0..7cad147a7a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -130,14 +130,13 @@ private void deleteConnector(DeleteRequest deleteRequest, String connectorId, Ac if (throwable != null) { actionListener.onFailure(new RuntimeException(throwable)); } else { - context.restore(); log.info("Connector deletion result: {}, connector id: {}", r.deleted(), r.id()); DeleteResponse response = new DeleteResponse(r.shardId(), r.id(), 0, 0, 0, r.deleted()); actionListener.onResponse(response); } }); } catch (Exception e) { - log.error("Failed to delete ML connector: " + connectorId, e); + log.error("Failed to delete ML connector: {}", connectorId, e); actionListener.onFailure(e); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index fda312e1c1..6edd58b356 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -8,23 +8,18 @@ */ package org.opensearch.ml.sdkclient; -import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.action.DocWriteResponse.Result.DELETED; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; - +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -36,9 +31,23 @@ import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.DELETED; +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import lombok.extern.log4j.Log4j2; + /** * An implementation of {@link SdkClient} that stores data in a local OpenSearch cluster using the Node Client. */ +@Log4j2 public class LocalClusterIndicesClient implements SdkClient { private final Client client; @@ -56,69 +65,60 @@ public LocalClusterIndicesClient(Client client, NamedXContentRegistry xContentRe @Override public CompletionStage putDataObjectAsync(PutDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { - client - .index( - new IndexRequest(request.index()) - .setRefreshPolicy(IMMEDIATE) - .source(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)), - ActionListener - .wrap( - r -> future - .complete(new PutDataObjectResponse.Builder().id(r.getId()).created(r.getResult() == CREATED).build()), - future::completeExceptionally - ) - ); - } catch (Exception e) { - future.completeExceptionally(e); - } - return future; + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + log.info("Indexing data object in {}", request.index()); + IndexResponse indexResponse = client + .index( + new IndexRequest(request.index()) + .setRefreshPolicy(IMMEDIATE) + .source(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)) + ) + .actionGet(); + log.info("Creation status for id {}: {}", indexResponse.getId(), indexResponse.getResult()); + return new PutDataObjectResponse.Builder().id(indexResponse.getId()).created(indexResponse.getResult() == CREATED).build(); + } catch (Exception e) { + throw new OpenSearchException(e); + } + })); } @Override public CompletionStage getDataObjectAsync(GetDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - try { - client.get(new GetRequest(request.index(), request.id()), ActionListener.wrap(r -> { - try { - XContentParser parser = jsonXContent - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()); - future.complete(new GetDataObjectResponse.Builder().id(r.getId()).parser(parser).build()); - } catch (IOException e) { - // Parsing error - future.completeExceptionally(e); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try { + log.info("Getting {} from {}", request.id(), request.index()); + GetResponse getResponse = client.get(new GetRequest(request.index(), request.id())).actionGet(); + if (!getResponse.isExists()) { + throw new OpenSearchStatusException("Data object with id " + request.id() + " not found", RestStatus.NOT_FOUND); } - }, future::completeExceptionally)); - } catch (Exception e) { - future.completeExceptionally(e); - } - return future; + XContentParser parser = jsonXContent + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()); + log.info("Retrieved data object"); + return new GetDataObjectResponse.Builder().id(getResponse.getId()).parser(parser).build(); + } catch (OpenSearchStatusException notFound) { + throw notFound; + } catch (Exception e) { + throw new OpenSearchException(e); + } + })); } @Override public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - try { - client - .delete( - new DeleteRequest(request.index(), request.id()), - ActionListener - .wrap( - r -> future - .complete( - new DeleteDataObjectResponse.Builder() - .id(r.getId()) - .shardId(r.getShardId()) - .deleted(r.getResult() == DELETED) - .build() - ), - future::completeExceptionally - ) - ); - } catch (Exception e) { - future.completeExceptionally(e); - } - return future; + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try { + log.info("Deleting {} from {}", request.id(), request.index()); + DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id())).actionGet(); + log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); + return new DeleteDataObjectResponse.Builder() + .id(deleteResponse.getId()) + .shardId(deleteResponse.getShardId()) + .deleted(deleteResponse.getResult() == DELETED) + .build(); + } catch (Exception e) { + throw new OpenSearchException(e); + } + })); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 3f24cfed73..05abbc035e 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -8,15 +8,10 @@ */ package org.opensearch.ml.sdkclient; -import static org.opensearch.client.opensearch._types.Result.Created; -import static org.opensearch.client.opensearch._types.Result.Deleted; - -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch.core.DeleteRequest; import org.opensearch.client.opensearch.core.DeleteResponse; @@ -26,6 +21,7 @@ import org.opensearch.client.opensearch.core.IndexResponse; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; @@ -36,9 +32,21 @@ import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +import static org.opensearch.client.opensearch._types.Result.Created; +import static org.opensearch.client.opensearch._types.Result.Deleted; + +import lombok.extern.log4j.Log4j2; + /** * An implementation of {@link SdkClient} that stores data in a remote OpenSearch cluster using the OpenSearch Java Client. */ +@Log4j2 public class RemoteClusterIndicesClient implements SdkClient { private OpenSearchClient openSearchClient; @@ -53,59 +61,55 @@ public RemoteClusterIndicesClient(OpenSearchClient openSearchClient) { @Override public CompletionStage putDataObjectAsync(PutDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - IndexRequest indexRequest = new IndexRequest.Builder<>().index(request.index()).document(request.dataObject()).build(); - AccessController.doPrivileged((PrivilegedAction) () -> { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { + IndexRequest indexRequest = new IndexRequest.Builder<>().index(request.index()).document(request.dataObject()).build(); + log.info("Indexing data object in {}", request.index()); IndexResponse indexResponse = openSearchClient.index(indexRequest); - future - .complete( - new PutDataObjectResponse.Builder().id(indexResponse.id()).created(indexResponse.result() == Created).build() - ); + log.info("Creation status for id {}: {}", indexResponse.id(), indexResponse.result()); + return new PutDataObjectResponse.Builder().id(indexResponse.id()).created(indexResponse.result() == Created).build(); } catch (Exception e) { - future.completeExceptionally(e); + throw new OpenSearchException("Error occurred while indexing data object", e); } - return null; - }); - return future; + })); } @Override public CompletionStage getDataObjectAsync(GetDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - GetRequest getRequest = new GetRequest.Builder().index(request.index()).id(request.id()).build(); - AccessController.doPrivileged((PrivilegedAction) () -> { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { + GetRequest getRequest = new GetRequest.Builder().index(request.index()).id(request.id()).build(); + log.info("Getting {} from {}", request.id(), request.index()); @SuppressWarnings("rawtypes") GetResponse getResponse = openSearchClient.get(getRequest, Map.class); - String source = getResponse.fields().get("_source").toJson().toString(); + if (!getResponse.found()) { + throw new OpenSearchStatusException("Data object with id " + request.id() + " not found", RestStatus.NOT_FOUND); + } + String json = new ObjectMapper().writeValueAsString(getResponse.source()); + log.info("Retrieved data object"); XContentParser parser = JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); - future.complete(new GetDataObjectResponse.Builder().id(getResponse.id()).parser(parser).build()); + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); + return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(parser).build(); + } catch (OpenSearchStatusException notFound) { + throw notFound; } catch (Exception e) { - future.completeExceptionally(e); + throw new OpenSearchException("Error occurred while getting data object", e); } - return null; - }); - return future; + })); } @Override public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request) { - CompletableFuture future = new CompletableFuture<>(); - DeleteRequest deleteRequest = new DeleteRequest.Builder().index(request.index()).id(request.id()).build(); - AccessController.doPrivileged((PrivilegedAction) () -> { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { + DeleteRequest deleteRequest = new DeleteRequest.Builder().index(request.index()).id(request.id()).build(); + log.info("Deleting {} from {}", request.id(), request.index()); DeleteResponse deleteResponse = openSearchClient.delete(deleteRequest); - future - .complete( - new DeleteDataObjectResponse.Builder().id(deleteResponse.id()).deleted(deleteResponse.result() == Deleted).build() - ); + log.info("Deletion status for id {}: {}", deleteResponse.id(), deleteResponse.result()); + return new DeleteDataObjectResponse.Builder().id(deleteResponse.id()).deleted(deleteResponse.result() == Deleted).build(); } catch (Exception e) { - future.completeExceptionally(e); + throw new OpenSearchException("Error occurred while deleting data object", e); } - return null; - }); - return future; + })); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index 867f4e291b..88075b1bfe 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -8,17 +8,18 @@ */ package org.opensearch.ml.sdkclient; +import org.apache.http.HttpHost; +import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.opensearch.OpenSearchException; +import org.opensearch.client.RestClient; +import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; -import org.opensearch.client.transport.aws.AwsSdk2Transport; -import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; +import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; import org.opensearch.core.common.Strings; import org.opensearch.sdk.SdkClient; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; +import com.fasterxml.jackson.databind.ObjectMapper; /** * A module for binding this plugin's desired implementation of {@link SdkClient}. @@ -29,7 +30,7 @@ public class SdkClientModule extends AbstractModule { public static final String REGION = "REGION"; private final String remoteMetadataEndpoint; - private final String region; + private final String region; // not using with RestClient /** * Instantiate this module using environment variables @@ -59,11 +60,21 @@ protected void configure() { } private OpenSearchClient createOpenSearchClient() { - SdkHttpClient httpClient = ApacheHttpClient.builder().build(); try { - return new OpenSearchClient( - new AwsSdk2Transport(httpClient, remoteMetadataEndpoint, Region.of(region), AwsSdk2TransportOptions.builder().build()) - ); + // Basic http(not-s) client using RestClient. + RestClient restClient = RestClient + // This HttpHost syntax works with export REMOTE_METADATA_ENDPOINT=http://127.0.0.1:9200 + .builder(HttpHost.create(remoteMetadataEndpoint)) + .setStrictDeprecationMode(true) + .setHttpClientConfigCallback(httpClientBuilder -> { + try { + return httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + } catch (Exception e) { + throw new OpenSearchException(e); + } + }) + .build(); + return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(new ObjectMapper()))); } catch (Exception e) { throw new OpenSearchException(e); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index f1009b1464..4f06a50586 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -8,22 +8,14 @@ */ package org.opensearch.ml.sdkclient; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; @@ -32,8 +24,10 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; @@ -45,6 +39,17 @@ import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class LocalClusterIndicesClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; @@ -70,20 +75,18 @@ public void setup() { public void testPutDataObject() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - - IndexResponse response = mock(IndexResponse.class); - when(response.getId()).thenReturn(TEST_ID); - when(response.getResult()).thenReturn(DocWriteResponse.Result.CREATED); - listener.onResponse(response); - return null; - }).when(mockedClient).index(any(IndexRequest.class), any()); + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn(TEST_ID); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.index(any(IndexRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(indexResponse); PutDataObjectResponse response = sdkClient.putDataObject(putRequest); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); - verify(mockedClient, times(1)).index(requestCaptor.capture(), any()); + verify(mockedClient, times(1)).index(requestCaptor.capture()); assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); assertTrue(response.created()); @@ -99,35 +102,25 @@ public void testPutDataObject_Exception() throws IOException { }).when(mockedClient).index(any(IndexRequest.class), any()); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); - } - - public void testPutDataObject_OuterException() throws IOException { - PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); - - doThrow(new NullPointerException("test")).when(mockedClient).index(any(IndexRequest.class), any()); - - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } public void testGetDataObject() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - - GetResponse response = mock(GetResponse.class); - when(response.getId()).thenReturn(TEST_ID); - when(response.getSourceAsString()).thenReturn(testDataObject.toJson()); - listener.onResponse(response); - return null; - }).when(mockedClient).get(any(GetRequest.class), any()); + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getId()).thenReturn(TEST_ID); + when(getResponse.getSourceAsString()).thenReturn(testDataObject.toJson()); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.get(any(GetRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(getResponse); GetDataObjectResponse response = sdkClient.getDataObject(getRequest); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); - verify(mockedClient, times(1)).get(requestCaptor.capture(), any()); + verify(mockedClient, times(1)).get(requestCaptor.capture()); assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); XContentParser parser = response.parser(); @@ -135,45 +128,49 @@ public void testGetDataObject() throws IOException { assertEquals("foo", TestDataObject.parse(parser).data()); } - public void testGetDataObject_Exception() throws IOException { + public void testGetDataObject_NotFound() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).get(any(GetRequest.class), any()); + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(false); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.get(any(GetRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(getResponse); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ose.getCause().getClass()); + assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) ose.getCause()).status()); } - public void testGetDataObject_OuterException() throws IOException { + public void testGetDataObject_Exception() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doThrow(new NullPointerException("test")).when(mockedClient).get(any(GetRequest.class), any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IOException("test")); + return null; + }).when(mockedClient).get(any(GetRequest.class), any()); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } public void testDeleteDataObject() throws IOException { DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - - DeleteResponse response = mock(DeleteResponse.class); - when(response.getId()).thenReturn(TEST_ID); - when(response.getResult()).thenReturn(DocWriteResponse.Result.DELETED); - listener.onResponse(response); - return null; - }).when(mockedClient).delete(any(DeleteRequest.class), any()); + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(deleteResponse.getId()).thenReturn(TEST_ID); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.delete(any(DeleteRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(deleteResponse); DeleteDataObjectResponse response = sdkClient.deleteDataObject(deleteRequest); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); - verify(mockedClient, times(1)).delete(requestCaptor.capture(), any()); + verify(mockedClient, times(1)).delete(requestCaptor.capture()); assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); } @@ -188,14 +185,6 @@ public void testDeleteDataObject_Exception() throws IOException { }).when(mockedClient).delete(any(DeleteRequest.class), any()); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); - } - - public void testDeleteDataObject_OuterException() throws IOException { - DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doThrow(new NullPointerException("test")).when(mockedClient).delete(any(DeleteRequest.class), any()); - - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index a679a7eba2..508531c2e4 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -8,19 +8,14 @@ */ package org.opensearch.ml.sdkclient; -import static org.mockito.Mockito.when; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; -import java.util.Map; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchException; -import org.opensearch.client.json.JsonData; -import org.opensearch.client.json.jackson.JacksonJsonpMapper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.Result; import org.opensearch.client.opensearch._types.ShardStatistics; @@ -30,6 +25,7 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -40,8 +36,13 @@ import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Map; + +import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RemoteClusterIndicesClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; @@ -91,17 +92,7 @@ public void testPutDataObject_Exception() throws IOException { when(mockedOpenSearchClient.index(indexRequestCaptor.capture())).thenThrow(new IOException("test")); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); - } - - public void testPutDataObject_InnerException() throws IOException { - PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); - - ArgumentCaptor> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); - when(mockedOpenSearchClient.index(indexRequestCaptor.capture())).thenReturn(null); - - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.putDataObject(putRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -110,9 +101,9 @@ public void testGetDataObject() throws IOException { GetResponse getResponse = new GetResponse.Builder<>() .index(TEST_INDEX) - .fields(Map.of("_source", JsonData.of(Map.of("data", "foo"), new JacksonJsonpMapper(new ObjectMapper())))) - .found(true) .id(TEST_ID) + .found(true) + .source(Map.of("data", "foo")) .build(); ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); @@ -129,27 +120,30 @@ public void testGetDataObject() throws IOException { } @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testGetDataObject_Exception() throws IOException { + public void testGetDataObject_NotFound() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); + GetResponse getResponse = new GetResponse.Builder<>().index(TEST_INDEX).id(TEST_ID).found(false).build(); + ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); ArgumentCaptor> mapClassCaptor = ArgumentCaptor.forClass(Class.class); - when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenThrow(new IOException("test")); + when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenReturn((GetResponse) getResponse); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ose.getCause().getClass()); + assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) ose.getCause()).status()); } @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testGetDataObject_InnerException() throws IOException { + public void testGetDataObject_Exception() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); ArgumentCaptor> mapClassCaptor = ArgumentCaptor.forClass(Class.class); - when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenReturn(null); + when(mockedOpenSearchClient.get(getRequestCaptor.capture(), mapClassCaptor.capture())).thenThrow(new IOException("test")); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.getDataObject(getRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } public void testDeleteDataObject() throws IOException { @@ -181,16 +175,6 @@ public void testDeleteDataObject_Exception() throws IOException { when(mockedOpenSearchClient.delete(deleteRequestCaptor.capture())).thenThrow(new IOException("test")); OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(IOException.class, ose.getCause().getClass()); - } - - public void testDeleteDataObject_InnerException() throws IOException { - DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - - ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); - when(mockedOpenSearchClient.delete(deleteRequestCaptor.capture())).thenReturn(null); - - OpenSearchException ose = assertThrows(OpenSearchException.class, () -> sdkClient.deleteDataObject(deleteRequest)); - assertEquals(NullPointerException.class, ose.getCause().getClass()); + assertEquals(OpenSearchException.class, ose.getCause().getClass()); } }