diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index 3923a00cc1..4ec4cfa251 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -9,6 +9,10 @@ package org.opensearch.sdk; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; public class UpdateDataObjectRequest { @@ -118,6 +122,20 @@ public Builder dataObject(ToXContentObject dataObject) { this.dataObject = dataObject; return this; } + + /** + * Add a data object as a map to this builder + * @param dataObjectMap the data object as a map of fields + * @return the updated builder + */ + public Builder dataObject(Map dataObjectMap) { + this.dataObject = new ToXContentObject() { + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.map(dataObjectMap); + }}; + return this; + } /** * Builds the request diff --git a/common/src/test/java/org/opensearch/sdk/DeleteDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/DeleteDataObjectRequestTests.java index ab0c53e791..4210e17215 100644 --- a/common/src/test/java/org/opensearch/sdk/DeleteDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/DeleteDataObjectRequestTests.java @@ -16,18 +16,21 @@ public class DeleteDataObjectRequestTests { private String testIndex; private String testId; + private String testTenantId; @Before public void setUp() { testIndex = "test-index"; testId = "test-id"; + testTenantId = "test-tenant-id"; } @Test public void testDeleteDataObjectRequest() { - DeleteDataObjectRequest request = new DeleteDataObjectRequest.Builder().index(testIndex).id(testId).build(); + DeleteDataObjectRequest request = new DeleteDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).build(); assertEquals(testIndex, request.index()); assertEquals(testId, request.id()); + assertEquals(testTenantId, request.tenantId()); } } diff --git a/common/src/test/java/org/opensearch/sdk/GetDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/GetDataObjectRequestTests.java index 1e594caec3..0fc88c4468 100644 --- a/common/src/test/java/org/opensearch/sdk/GetDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/GetDataObjectRequestTests.java @@ -19,12 +19,14 @@ public class GetDataObjectRequestTests { private String testIndex; private String testId; + private String testTenantId; private FetchSourceContext testFetchSourceContext; @Before public void setUp() { testIndex = "test-index"; testId = "test-id"; + testTenantId = "test-tenant-id"; testFetchSourceContext = mock(FetchSourceContext.class); } @@ -33,11 +35,13 @@ public void testGetDataObjectRequest() { GetDataObjectRequest request = new GetDataObjectRequest.Builder() .index(testIndex) .id(testId) + .tenantId(testTenantId) .fetchSourceContext(testFetchSourceContext) .build(); assertEquals(testIndex, request.index()); assertEquals(testId, request.id()); + assertEquals(testTenantId, request.tenantId()); assertEquals(testFetchSourceContext, request.fetchSourceContext()); } } diff --git a/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java index 8903cb19c4..b1909966da 100644 --- a/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java @@ -18,19 +18,22 @@ public class PutDataObjectRequestTests { private String testIndex; + private String testTenantId; private ToXContentObject testDataObject; @Before public void setUp() { testIndex = "test-index"; + testTenantId = "test-tenant-id"; testDataObject = mock(ToXContentObject.class); } @Test public void testPutDataObjectRequest() { - PutDataObjectRequest request = new PutDataObjectRequest.Builder().index(testIndex).dataObject(testDataObject).build(); + PutDataObjectRequest request = new PutDataObjectRequest.Builder().index(testIndex).tenantId(testTenantId).dataObject(testDataObject).build(); assertEquals(testIndex, request.index()); + assertEquals(testTenantId, request.tenantId()); assertEquals(testDataObject, request.dataObject()); } } diff --git a/common/src/test/java/org/opensearch/sdk/SearchDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/SearchDataObjectRequestTests.java index 263481ec32..1d1caf8c5e 100644 --- a/common/src/test/java/org/opensearch/sdk/SearchDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/SearchDataObjectRequestTests.java @@ -18,11 +18,13 @@ public class SearchDataObjectRequestTests { private String[] testIndices; + private String testTenantId; private SearchSourceBuilder testSearchSourceBuilder; @Before public void setUp() { testIndices = new String[] {"test-index"}; + testTenantId = "test-tenant-id"; testSearchSourceBuilder = new SearchSourceBuilder(); } @@ -30,10 +32,12 @@ public void setUp() { public void testGetDataObjectRequest() { SearchDataObjectRequest request = new SearchDataObjectRequest.Builder() .indices(testIndices) + .tenantId(testTenantId) .searchSourceBuilder(testSearchSourceBuilder) .build(); assertArrayEquals(testIndices, request.indices()); + assertEquals(testTenantId, request.tenantId()); assertEquals(testSearchSourceBuilder, request.searchSourceBuilder()); } } diff --git a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java new file mode 100644 index 0000000000..f5718e05c0 --- /dev/null +++ b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +public class UpdateDataObjectRequestTests { + + private String testIndex; + private String testId; + private String testTenantId; + private ToXContentObject testDataObject; + private Map testDataObjectMap; + + @Before + public void setUp() { + testIndex = "test-index"; + testId = "test-id"; + testTenantId = "test-tenant-id"; + testDataObject = mock(ToXContentObject.class); + testDataObjectMap = Map.of("foo", "bar"); + } + + @Test + public void testUpdateDataObjectRequest() { + UpdateDataObjectRequest request = new UpdateDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObject).build(); + + assertEquals(testIndex, request.index()); + assertEquals(testId, request.id()); + assertEquals(testTenantId, request.tenantId()); + assertEquals(testDataObject, request.dataObject()); + } + + @Test + public void testUpdateDataObjectMapRequest() { + UpdateDataObjectRequest request = new UpdateDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObjectMap).build(); + + assertEquals(testIndex, request.index()); + assertEquals(testId, request.id()); + assertEquals(testTenantId, request.tenantId()); + assertEquals(testDataObjectMap, XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(XContentType.JSON, request.dataObject()), false)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 70e522ec13..04b73ae3ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -377,7 +377,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( modelAccessControlHelper .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { if (hasNewModelGroupPermission) { - mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { buildUpdateRequest( modelId, newModelGroupId, diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 4aa83cf665..f22e51b24a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -196,7 +196,7 @@ private void checkUserAccess( ) { User user = RestActionUtils.getUserContext(client); modelAccessControlHelper - .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { + .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, sdkClient, ActionListener.wrap(access -> { if (access) { doRegister(registerModelInput, listener); return; @@ -351,7 +351,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener); + mlModelManager.registerMLRemoteModel(sdkClient, registerModelInput, mlTask, listener); }, e -> { logException("Failed to register model", e, log); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 7f6f634b26..fab5ed0397 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -13,7 +13,6 @@ import java.util.HashSet; import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; @@ -257,25 +256,6 @@ public void validateUniqueModelGroupName(String name, ActionListener listener) { - GetRequest getRequest = new GetRequest(); - getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - listener.onResponse(r); - } else { - listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); - } - }, listener::onFailure)); - } - /** * Get model group from model group index. * diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 901c109da5..f1ae25deda 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -140,8 +140,10 @@ import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.script.ScriptService; import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.SdkClient; import org.opensearch.sdk.SdkClientUtils; +import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; @@ -347,11 +349,13 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput /** * + * @param sdkClient metadata client * @param mlRegisterModelInput register model input for remote models * @param mlTask ML task * @param listener action listener */ public void registerMLRemoteModel( + SdkClient sdkClient, MLRegisterModelInput mlRegisterModelInput, MLTask mlTask, ActionListener listener @@ -363,49 +367,78 @@ public void registerMLRemoteModel( mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { - if (getModelGroupResponse.isExists()) { - Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); - int updatedVersion = incrementLatestVersion(modelGroupSourceMap); - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( - modelGroupSourceMap, - modelGroupId, - getModelGroupResponse.getSeqNo(), - getModelGroupResponse.getPrimaryTerm(), - updatedVersion - ); - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - indexRemoteModel(mlRegisterModelInput, mlTask, updatedVersion + "", listener); - }, e -> { - log.error("Failed to update model group {}", modelGroupId, e); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); - listener.onFailure(e); - })); - } else { - log.error("Model group response is empty"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLValidationException("Model group not found") - ); - listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); - } - }, error -> { - if (error instanceof IndexNotFoundException) { - log.error("Model group Index is missing"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLResourceNotFoundException("Failed to get model group due to index missing") - ); - listener.onFailure(error); - } else { - log.error("Failed to get model group", error); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), error); - listener.onFailure(error); - } - })); + GetDataObjectRequest getModelGroupRequest = new GetDataObjectRequest.Builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .build(); + sdkClient + .getDataObjectAsync(getModelGroupRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse getModelGroupResponse = GetResponse.fromXContent(r.parser()); + if (getModelGroupResponse.isExists()) { + Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); + int updatedVersion = incrementLatestVersion(modelGroupSourceMap); + /* TODO UpdateDataObjectRequest needs to track response seqNo + primary term + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + modelGroupSourceMap, + modelGroupId, + getModelGroupResponse.getSeqNo(), + getModelGroupResponse.getPrimaryTerm(), + updatedVersion + ); + */ + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateDataObjectRequest updateDataObjectRequest = new UpdateDataObjectRequest.Builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + // TODO need to track these for concurrency + // .setIfSeqNo(seqNo) + // .setIfPrimaryTerm(primaryTerm) + // .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .dataObject(modelGroupSourceMap) + .build(); + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { + if (ut == null) { + indexRemoteModel(sdkClient, mlRegisterModelInput, mlTask, updatedVersion + "", listener); + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(ut); + log.error("Failed to update model group {}", modelGroupId, e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + listener.onFailure(e); + } + }); + } else { + log.error("Model group response is empty"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + new MLValidationException("Model group not found") + ); + listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); + } + } catch (Exception e) { + listener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + if (e instanceof IndexNotFoundException) { + log.error("Model group Index is missing"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + new MLResourceNotFoundException("Failed to get model group due to index missing") + ); + listener.onFailure(e); + } else { + log.error("Failed to get model group", e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + listener.onFailure(e); + } + } + }); } catch (Exception e) { log.error("Failed to register remote model", e); handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); @@ -512,6 +545,7 @@ private int incrementLatestVersion(Map modelGroupSourceMap) { } private void indexRemoteModel( + SdkClient sdkClient, MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion, @@ -550,12 +584,11 @@ private void indexRemoteModel( .tenantId(registerModelInput.getTenantId()) .build(); - IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); - if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { - indexModelMetaRequest.id(modelName); - } - indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); - indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + PutDataObjectRequest putModelMetaRequest = new PutDataObjectRequest.Builder() + .index(ML_MODEL_INDEX) + .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) + .dataObject(mlModelMeta) + .build(); // index remote model doc ActionListener indexListener = ActionListener.wrap(modelMetaRes -> { @@ -572,8 +605,22 @@ private void indexRemoteModel( handleException(functionName, taskId, e); listener.onFailure(e); }); - - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); + ThreadedActionListener putListener = threadedActionListener(REGISTER_THREAD_POOL, indexListener); + sdkClient + .putDataObjectAsync(putModelMetaRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + IndexResponse ir = IndexResponse.fromXContent(r.parser()); + putListener.onResponse(ir); + } catch (Exception e) { + putListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + putListener.onFailure(e); + } + }); }, error -> { // failed to initialize the model index log.error("Failed to init model index", error); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 3146955a2e..5129bfa16a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -311,7 +311,7 @@ public void setup() throws IOException { ) .build(); - // TODO eventually remove + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); @@ -326,7 +326,7 @@ public void setup() throws IOException { .when(modelAccessControlHelper) .validateModelGroupAccess(any(), eq("test_model_group_id"), any(), any(SdkClient.class), isA(ActionListener.class)); - // TODO eventually remove + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); @@ -359,12 +359,12 @@ public void setup() throws IOException { future.onResponse(updateResponse); when(client.update(any(UpdateRequest.class))).thenReturn(future); - // TODO eventually remove + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); + ActionListener listener = invocation.getArgument(3); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); @@ -386,10 +386,10 @@ public void setup() throws IOException { GetResponse getResponse = prepareGetResponse(modelGroup); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("updated_test_model_group_id"), isA(ActionListener.class)); } @AfterClass @@ -692,10 +692,10 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherExc @Test public void testUpdateModelWithRegisterToNewModelGroupNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("updated_test_model_group_id"), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -800,7 +800,7 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq(sdkClient), eq("mockId"), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -809,10 +809,12 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO doReturn("mockUpdateModelGroupId").when(mockUpdateModelInput).getModelGroupId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), eq(sdkClient), isA(ActionListener.class)); MLModelGroup modelGroup = MLModelGroup .builder() @@ -828,10 +830,10 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO GetResponse getResponse = prepareGetResponse(modelGroup); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("mockUpdateModelGroupId"), isA(ActionListener.class)); doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index fd511c4ff8..8cca9f7a26 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.register; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; @@ -13,6 +14,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; @@ -22,8 +25,11 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -32,6 +38,7 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; @@ -40,6 +47,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -74,12 +83,26 @@ import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; public class TransportRegisterModelActionTests extends OpenSearchTestCase { + + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportRegisterModelActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -204,10 +227,10 @@ public void setup() throws IOException { assertNotNull(transportRegisterModelAction); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -237,10 +260,16 @@ public void setup() throws IOException { when(node2.getId()).thenReturn("node2Id"); doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLModel(any(), any()); - doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLRemoteModel(any(), any(), any()); + doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLRemoteModel(any(), any(), any(), any()); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testDoExecute_LocalModelDisabledException() { @@ -279,14 +308,18 @@ public void testDoExecute_LocalModelDisabledException() { ); } - public void testDoExecute_userHasNoAccessException() { + public void testDoExecute_userHasNoAccessException() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("You don't have permissions to perform this operation on this model.", argumentCaptor.getValue().getMessage()); @@ -434,14 +467,18 @@ public void testTransportRegisterModelActionDoExecuteWithDispatchException() { verify(actionListener).onFailure(argumentCaptor.capture()); } - public void test_ValidationFailedException() { + public void test_ValidationFailedException() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); @@ -476,7 +513,7 @@ public void test_execute_registerRemoteModel_withConnectorId_success() { MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); + verify(mlModelManager).registerMLRemoteModel(eq(sdkClient), eq(input), isA(MLTask.class), eq(actionListener)); } public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConnectorId() { @@ -542,7 +579,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); + verify(mlModelManager).registerMLRemoteModel(eq(sdkClient), eq(input), isA(MLTask.class), eq(actionListener)); } public void test_execute_registerRemoteModel_withInternalConnector_connectorIsNull() { @@ -598,7 +635,7 @@ public void test_ModelNameAlreadyExists() throws IOException { verify(actionListener).onResponse(argumentCaptor.capture()); } - public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException, InterruptedException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -607,10 +644,10 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -619,7 +656,11 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException .version("1") .build(); - transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -643,7 +684,7 @@ public void test_FailureWhenSearchingModelGroupName() throws IOException { assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); } - public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { + public void test_NoAccessWhenModelNameAlreadyExists() throws IOException, InterruptedException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { @@ -653,12 +694,15 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); - transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 13aa0ab109..7fb37ad887 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -6,7 +6,6 @@ package org.opensearch.ml.model; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -469,31 +468,6 @@ public void test_ExceptionInitModelGroupIndexIfAbsent() throws InterruptedExcept assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } - // Remove when all calls to the non-sdkclient method are migrated - @Test - public void test_SuccessGetModelGroup_NoSdkClient() throws IOException { - MLModelGroup modelGroup = MLModelGroup - .builder() - .modelGroupId("testModelGroupID") - .name("test") - .description("this is test group") - .latestVersion(1) - .backendRoles(Arrays.asList("role1", "role2")) - .owner(new User()) - .access(AccessMode.PUBLIC.name()) - .build(); - - GetResponse getResponse = prepareGetResponse(modelGroup); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); - verify(modelGroupListener).onResponse(getResponse); - } - @Test public void test_SuccessGetModelGroup() throws IOException, InterruptedException { MLModelGroup modelGroup = MLModelGroup @@ -523,17 +497,19 @@ public void test_SuccessGetModelGroup() throws IOException, InterruptedException } @Test - public void test_OtherExceptionGetModelGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener - .onFailure( - new RuntimeException("Any other Exception occurred during getting the model group. Please check log for more details.") - ); - return null; - }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + public void test_OtherExceptionGetModelGroup() throws IOException, InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future + .onFailure( + new RuntimeException("Any other Exception occurred during getting the model group. Please check log for more details.") + ); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(modelGroupListener, latch); + mlModelGroupManager.getModelGroupResponse(sdkClient, "testModelGroupID", latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -543,14 +519,16 @@ public void test_OtherExceptionGetModelGroup() throws IOException { } @Test - public void test_NotFoundGetModelGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(null); - return null; - }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + public void test_NotFoundGetModelGroup() throws IOException, InterruptedException { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(null); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(modelGroupListener, latch); + mlModelGroupManager.getModelGroupResponse(sdkClient, "testModelGroupID", latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index d42fa9ca65..326721803d 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -21,12 +21,15 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH; import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; @@ -60,9 +63,12 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -70,17 +76,22 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.MemoryCircuitBreaker; @@ -107,6 +118,7 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; @@ -116,7 +128,10 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.script.ScriptService; +import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; @@ -125,12 +140,24 @@ public class MLModelManagerTests extends OpenSearchTestCase { + private static TestThreadPool testThreadPool = new TestThreadPool( + MLModelManagerTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Rule public ExpectedException expectedEx = ExpectedException.none(); private ClusterService clusterService; @Mock private Client client; + private SdkClient sdkClient; @Mock private ThreadPool threadPool; private NamedXContentRegistry xContentRegistry; @@ -183,6 +210,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { public void setup() throws URISyntaxException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; MockitoAnnotations.openMocks(this); + encryptor = new EncryptorImpl(masterKey); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build(); @@ -198,6 +226,7 @@ public void setup() throws URISyntaxException { ); clusterService = spy(new ClusterService(settings, clusterSettings, null)); xContentRegistry = NamedXContentRegistry.EMPTY; + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); modelName = "model_name1"; modelId = randomAlphaOfLength(10); @@ -251,6 +280,7 @@ public void setup() throws URISyntaxException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); modelManager = spy( new MLModelManager( @@ -309,6 +339,11 @@ public void setup() throws URISyntaxException { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + public void testRegisterMLModel_ExceedMaxRunningTask() { String error = "exceed max running task limit"; doThrow(new MLLimitExceededException(error)).when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); @@ -431,7 +466,7 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE ); } - public void testRegisterMLRemoteModel() throws PrivilegedActionException { + public void testRegisterMLRemoteModel() throws PrivilegedActionException, InterruptedException { ActionListener listener = mock(ActionListener.class); doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); @@ -441,13 +476,17 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException { MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); - doAnswer(invocation -> { - ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; - indexResponseActionListener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - when(indexResponse.getId()).thenReturn("mockIndexId"); - modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); + + IndexResponse indexResponse = new IndexResponse(new ShardId(ML_MODEL_INDEX, "_na_", 0), "mockIndexId", 1, 0, 2, true); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(indexResponse); + when(client.index(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(listener, latch); + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } @@ -460,7 +499,7 @@ public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); - modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index ebb945e60c..2cd79a8629 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -392,6 +392,29 @@ public void updateDataObjectAsync_HappyCase() { } + @Test + public void updateDataObjectAsync_HappyCaseWithMap() { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .id(TEST_ID) + .index(TEST_INDEX) + .tenantId(TENANT_ID) + .dataObject(Map.of("foo", "bar")) + .build(); + Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build()); + UpdateDataObjectResponse updateResponse = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + assertEquals(TEST_ID, updateResponse.id()); + UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); + assertEquals(TEST_ID, updateRequest.id()); + assertEquals(TEST_INDEX, updateItemRequest.tableName()); + assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); + assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + assertEquals("bar", updateItemRequest.key().get("foo").s()); + + } + @Test public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() { UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index 01fd9447c9..f402b5c102 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.EnumSet; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; @@ -307,6 +308,37 @@ public void testUpdateDataObject() throws IOException { assertEquals(1, updateActionResponse.getShardInfo().getTotal()); } + public void testUpdateDataObjectWithMap() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(Map.of("foo", "bar")) + .build(); + + UpdateResponse updateResponse = new UpdateResponse( + new ShardInfo(1, 1), + new ShardId(TEST_INDEX, "_na_", 0), + TEST_ID, + 1, + 0, + 2, + Result.UPDATED + ); + + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.update(any(UpdateRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(updateResponse); + + sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(mockedClient, times(1)).update(requestCaptor.capture()); + assertEquals(TEST_INDEX, requestCaptor.getValue().index()); + assertEquals(TEST_ID, requestCaptor.getValue().id()); + assertEquals("bar", requestCaptor.getValue().doc().sourceAsMap().get("foo")); + } + public void testUpdateDataObject_NotFound() throws IOException { UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() .index(TEST_INDEX) diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 4a2ed4cab1..467bc2c137 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -315,6 +315,34 @@ public void testUpdateDataObject() throws IOException { assertEquals(1, updateActionResponse.getShardInfo().getTotal()); } + public void testUpdateDataObjectWithMap() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(Map.of("foo", "bar")) + .build(); + + UpdateResponse> updateResponse = new UpdateResponse.Builder>() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Updated) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + @SuppressWarnings("unchecked") + ArgumentCaptor, ?>> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenReturn(updateResponse); + + sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); + + assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index()); + assertEquals(TEST_ID, updateRequestCaptor.getValue().id()); + assertEquals("bar", ((Map) updateRequestCaptor.getValue().doc()).get("foo")); + } + public void testUpdateDataObject_NotFound() throws IOException { UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() .index(TEST_INDEX)