diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index f0d6f08127..8d98f86b6a 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -323,8 +323,21 @@ default ActionFuture deleteConnector(String connectorId) { return actionFuture; } + /** + * Delete connector for remote model + * @param connectorId The id of the connector to delete + * @return the result future + */ + default ActionFuture deleteConnector(String connectorId, String tenantId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deleteConnector(connectorId, tenantId, actionFuture); + return actionFuture; + } + void deleteConnector(String connectorId, ActionListener listener); + void deleteConnector(String connectorId, String tenantId, ActionListener listener); + /** * Register model group * For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 5819d055e7..288a9f2e3a 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -146,7 +146,7 @@ public void run(MLInput mlInput, Map args, ActionListener getMlGetModelResponseActionListener(A ActionListener internalListener = ActionListener.wrap(predictionResponse -> { listener.onResponse(predictionResponse.getMlModel()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, res -> { - MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(res); - return getResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse); } @Override public void deleteModel(String modelId, ActionListener listener) { MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); - client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> { - listener.onResponse(deleteResponse); - }, listener::onFailure)); + client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @Override public void searchModel(SearchRequest searchRequest, ActionListener listener) { - client - .execute( - MLModelSearchAction.INSTANCE, - searchRequest, - ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure) - ); + client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @Override @@ -238,19 +227,12 @@ public void getTask(String taskId, ActionListener listener) { public void deleteTask(String taskId, ActionListener listener) { MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build(); - client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> { - listener.onResponse(deleteResponse); - }, listener::onFailure)); + client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @Override public void searchTask(SearchRequest searchRequest, ActionListener listener) { - client - .execute( - MLTaskSearchAction.INSTANCE, - searchRequest, - ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure) - ); + client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @Override @@ -280,9 +262,23 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio @Override public void deleteConnector(String connectorId, ActionListener listener) { MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId); - client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> { - listener.onResponse(deleteResponse); - }, listener::onFailure)); + client + .execute( + MLConnectorDeleteAction.INSTANCE, + connectorDeleteRequest, + ActionListener.wrap(listener::onResponse, listener::onFailure) + ); + } + + @Override + public void deleteConnector(String connectorId, String tenantId, ActionListener listener) { + MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId); + client + .execute( + MLConnectorDeleteAction.INSTANCE, + connectorDeleteRequest, + ActionListener.wrap(listener::onResponse, listener::onFailure) + ); } @Override @@ -294,9 +290,7 @@ public void registerAgent(MLAgent mlAgent, ActionListener listener) { MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId); - client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> { - listener.onResponse(deleteResponse); - }, listener::onFailure)); + client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @Override @@ -324,123 +318,78 @@ private ActionListener getMlListToolsResponseActionListener ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, res -> { - MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res); - return getResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLToolsListResponse::fromActionResponse); } private ActionListener getMlGetToolResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(mlModelGetResponse -> { listener.onResponse(mlModelGetResponse.getToolMetadata()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, res -> { - MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res); - return getResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLToolGetResponse::fromActionResponse); } private ActionListener getMlGetConfigResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(mlConfigGetResponse -> { listener.onResponse(mlConfigGetResponse.getMlConfig()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, res -> { - MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res); - return getResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLConfigGetResponse::fromActionResponse); } private ActionListener getMLRegisterAgentResponseActionListener( ActionListener listener ) { - ActionListener actionListener = wrapActionListener(listener, res -> { - MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res); - return mlRegisterAgentResponse; - }); - return actionListener; + return wrapActionListener(listener, MLRegisterAgentResponse::fromActionResponse); } private ActionListener getMLTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener .wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, response -> { - MLTaskGetResponse getResponse = MLTaskGetResponse.fromActionResponse(response); - return getResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLTaskGetResponse::fromActionResponse); } private ActionListener getMlDeployModelResponseActionListener(ActionListener listener) { - ActionListener actionListener = wrapActionListener(listener, response -> { - MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response); - return deployModelResponse; - }); - return actionListener; + return wrapActionListener(listener, MLDeployModelResponse::fromActionResponse); } private ActionListener getMlUndeployModelsResponseActionListener( ActionListener listener ) { - ActionListener actionListener = wrapActionListener(listener, response -> { - MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response); - return deployModelResponse; - }); - return actionListener; + return wrapActionListener(listener, MLUndeployModelsResponse::fromActionResponse); } private ActionListener getMlCreateConnectorResponseActionListener( ActionListener listener ) { - ActionListener actionListener = wrapActionListener(listener, response -> { - MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response); - return createConnectorResponse; - }); - return actionListener; + return wrapActionListener(listener, MLCreateConnectorResponse::fromActionResponse); } private ActionListener getMlRegisterModelGroupResponseActionListener( ActionListener listener ) { - ActionListener actionListener = wrapActionListener(listener, response -> { - MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response); - return registerModelGroupResponse; - }); - return actionListener; + return wrapActionListener(listener, MLRegisterModelGroupResponse::fromActionResponse); } private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(predictionResponse -> { listener.onResponse(predictionResponse.getOutput()); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, res -> { - MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res); - return predictionResponse; - }); - return actionListener; + return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse); } private ActionListener getMLRegisterModelResponseActionListener( ActionListener listener ) { - ActionListener actionListener = wrapActionListener(listener, res -> { - MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res); - return registerModelResponse; - }); - return actionListener; + return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse); } private ActionListener wrapActionListener( final ActionListener listener, final Function recreate ) { - ActionListener actionListener = ActionListener.wrap(r -> { + return ActionListener.wrap(r -> { listener.onResponse(recreate.apply(r)); ; - }, e -> { listener.onFailure(e); }); - return actionListener; + }, listener::onFailure); } private void validateMLInput(MLInput mlInput, boolean requireInput) { diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 5100e7cb19..b0bdb80db8 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -216,6 +216,11 @@ public void execute(FunctionName name, Input input, ActionListener listener) { + listener.onResponse(deleteResponse); + } + @Override public void deleteConnector(String connectorId, ActionListener listener) { listener.onResponse(deleteResponse); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 6fea6e8c60..0f4904e20c 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -17,6 +17,7 @@ import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.input.Constants.ACTION; import static org.opensearch.ml.common.input.Constants.ALGORITHM; +import static org.opensearch.ml.common.input.Constants.ASYNC; import static org.opensearch.ml.common.input.Constants.KMEANS; import static org.opensearch.ml.common.input.Constants.MODELID; import static org.opensearch.ml.common.input.Constants.PREDICT; @@ -251,6 +252,42 @@ public void predict() { assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult()); } + @Test + public void execute_train_asyncTask() { + String modelId = "test_model_id"; + String status = "InProgress"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build(); + actionListener.onResponse(MLTaskResponse.builder().output(output).build()); + return null; + }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLOutput.class); + Map args = new HashMap<>(); + args.put(ACTION, TRAIN); + args.put(ALGORITHM, KMEANS); + args.put(ASYNC, true); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build(); + machineLearningNodeClient.run(mlInput, args, trainingActionListener); + + verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); + verify(trainingActionListener).onResponse(argumentCaptor.capture()); + assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId()); + assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus()); + } + + @Test + public void execute_predict_missing_modelId() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The model ID is required for prediction."); + Map args = new HashMap<>(); + args.put(ACTION, PREDICT); + args.put(ALGORITHM, KMEANS); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build(); + machineLearningNodeClient.run(mlInput, args, dataFrameActionListener); + } + @Test public void predict_Exception_WithNullAlgorithm() { exceptionRule.expect(IllegalArgumentException.class); @@ -288,6 +325,27 @@ public void train() { assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus()); } + @Test + public void registerModelGroup_withValidInput() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse("groupId", "created"); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), any(), any()); + + MLRegisterModelGroupInput input = MLRegisterModelGroupInput + .builder() + .name("test") + .description("description") + .backendRoles(Arrays.asList("role1", "role2")) + .modelAccessMode(AccessMode.PUBLIC) + .build(); + + machineLearningNodeClient.registerModelGroup(input, registerModelGroupResponseActionListener); + verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any()); + } + @Test public void train_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); @@ -499,6 +557,26 @@ public void getModel() { assertEquals(modelContent, argumentCaptor.getValue().getContent()); } + @Test + public void deleteConnector_withTenantId() { + String connectorId = "connectorId"; + String tenantId = "tenantId"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, connectorId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLConnectorDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + machineLearningNodeClient.deleteConnector(connectorId, tenantId, deleteConnectorActionListener); + + verify(client).execute(eq(MLConnectorDeleteAction.INSTANCE), isA(MLConnectorDeleteRequest.class), any()); + verify(deleteConnectorActionListener).onResponse(argumentCaptor.capture()); + assertEquals(connectorId, (argumentCaptor.getValue()).getId()); + } + @Test public void deleteModel() { String modelId = "testModelId"; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index fb5badf2ea..595e7fcee6 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -41,7 +41,8 @@ public AwsConnector( List backendRoles, AccessMode accessMode, User owner, - ConnectorClientConfig connectorClientConfig + ConnectorClientConfig connectorClientConfig, + String tenantId ) { super( name, @@ -54,7 +55,8 @@ public AwsConnector( backendRoles, accessMode, owner, - connectorClientConfig + connectorClientConfig, + tenantId ); validate(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index edf26b954d..c7d23bc8f5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.connector; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -24,6 +26,7 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; @@ -63,7 +66,8 @@ public HttpConnector( List backendRoles, AccessMode accessMode, User owner, - ConnectorClientConfig connectorClientConfig + ConnectorClientConfig connectorClientConfig, + String tenantId ) { validateProtocol(protocol); this.name = name; @@ -77,6 +81,7 @@ public HttpConnector( this.access = accessMode; this.owner = owner; this.connectorClientConfig = connectorClientConfig; + this.tenantId = tenantId; } @@ -138,6 +143,9 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException case CLIENT_CONFIG_FIELD: connectorClientConfig = ConnectorClientConfig.parse(parser); break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -187,6 +195,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (connectorClientConfig != null) { builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -202,6 +213,7 @@ public HttpConnector(StreamInput input) throws IOException { } private void parseFromStream(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); this.name = input.readOptionalString(); this.version = input.readOptionalString(); this.description = input.readOptionalString(); @@ -230,10 +242,14 @@ private void parseFromStream(StreamInput input) throws IOException { if (input.readBoolean()) { this.connectorClientConfig = new ConnectorClientConfig(input); } + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = input.readOptionalString(); + } } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(protocol); out.writeOptionalString(name); out.writeOptionalString(version); @@ -280,6 +296,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -296,10 +315,10 @@ public void update(MLCreateConnectorInput updateContent, Function 0) { + if (updateContent.getParameters() != null && !updateContent.getParameters().isEmpty()) { getParameters().putAll(updateContent.getParameters()); } - if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) { + if (updateContent.getCredential() != null && !updateContent.getCredential().isEmpty()) { this.credential = updateContent.getCredential(); encrypt(function); } @@ -367,7 +386,7 @@ public void decrypt(String action, Function function) { } this.decryptedCredential = decrypted; Optional connectorAction = findAction(action); - Map headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null; + Map headers = connectorAction.map(ConnectorAction::getHeaders).orElse(null); this.decryptedHeaders = createDecryptedHeaders(headers); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java index a1e3a6391e..8f1f1f146d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.connector; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -22,24 +24,41 @@ import lombok.Builder; import lombok.Getter; +@Getter public class MLConnectorDeleteRequest extends ActionRequest { - @Getter - String connectorId; + private final String connectorId; + private final String tenantId; @Builder + public MLConnectorDeleteRequest(String connectorId, String tenantId) { + this.connectorId = connectorId; + this.tenantId = tenantId; + } + public MLConnectorDeleteRequest(String connectorId) { this.connectorId = connectorId; + this.tenantId = null; } public MLConnectorDeleteRequest(StreamInput input) throws IOException { super(input); + Version streamInputVersion = input.getVersion(); this.connectorId = input.readString(); + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = input.readOptionalString(); + } else { + this.tenantId = null; + } } @Override public void writeTo(StreamOutput output) throws IOException { + Version streamOutputVersion = output.getVersion(); super.writeTo(output); output.writeString(connectorId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java index dbd7c9b42c..3265db4ff2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java @@ -35,6 +35,10 @@ public MLConnectorGetResponse(StreamInput in) throws IOException { mlConnector = Connector.fromStream(in); } + public Connector getMlConnector() { + return mlConnector; + } + @Override public void writeTo(StreamOutput out) throws IOException { mlConnector.writeTo(out); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 7029fccb7e..99dc51ab99 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.connector; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; @@ -30,6 +32,7 @@ import lombok.Builder; import lombok.Data; +import lombok.Setter; @Data public class MLCreateConnectorInput implements ToXContentObject, Writeable { @@ -56,6 +59,8 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private String description; private String version; private String protocol; + @Setter + private String tenantId; private Map parameters; private Map credential; private List actions; @@ -80,7 +85,8 @@ public MLCreateConnectorInput( AccessMode access, boolean dryRun, boolean updateConnector, - ConnectorClientConfig connectorClientConfig + ConnectorClientConfig connectorClientConfig, + String tenantId ) { if (!dryRun && !updateConnector) { @@ -110,6 +116,7 @@ public MLCreateConnectorInput( this.dryRun = dryRun; this.updateConnector = updateConnector; this.connectorClientConfig = connectorClientConfig; + this.tenantId = tenantId; } @@ -130,6 +137,7 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update AccessMode access = null; boolean dryRun = false; ConnectorClientConfig connectorClientConfig = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -181,6 +189,9 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update case AbstractConnector.CLIENT_CONFIG_FIELD: connectorClientConfig = ConnectorClientConfig.parse(parser); break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -199,7 +210,8 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update access, dryRun, updateConnector, - connectorClientConfig + connectorClientConfig, + tenantId ); } @@ -239,6 +251,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (connectorClientConfig != null) { builder.field(AbstractConnector.CLIENT_CONFIG_FIELD, connectorClientConfig); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -294,6 +309,9 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } } public MLCreateConnectorInput(StreamInput input) throws IOException { @@ -329,6 +347,9 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { this.connectorClientConfig = new ConnectorClientConfig(input); } } + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = input.readOptionalString(); + } } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java index b6ad6b054f..c85841184c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java @@ -18,6 +18,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLConnectorDeleteRequestTests { @@ -94,4 +95,27 @@ public void fromActionRequestWithConnectorDeleteRequestSuccess() { assertSame(mlConnectorDeleteRequest, mlConnectorDeleteRequestFromActionRequest); assertEquals(mlConnectorDeleteRequest.getConnectorId(), mlConnectorDeleteRequestFromActionRequest.getConnectorId()); } + + @Test + public void testConstructorWithTenantId() { + String tenantId = "test_tenant"; + MLConnectorDeleteRequest request = MLConnectorDeleteRequest.builder().connectorId(connectorId).tenantId(tenantId).build(); + + assertEquals(connectorId, request.getConnectorId()); + assertEquals(tenantId, request.getTenantId()); + } + + @Test + public void testWriteToWithTenantId() throws IOException { + String tenantId = "test_tenant"; + MLConnectorDeleteRequest request = MLConnectorDeleteRequest.builder().connectorId(connectorId).tenantId(tenantId).build(); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + + StreamInput input = output.bytes().streamInput(); + MLConnectorDeleteRequest parsedRequest = new MLConnectorDeleteRequest(input); + + assertEquals(connectorId, parsedRequest.getConnectorId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 936c71da95..1dee760f4f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -9,6 +9,7 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import java.io.IOException; @@ -108,4 +109,10 @@ public void writeTo(StreamOutput out) throws IOException { }; MLConnectorGetResponse.fromActionResponse(actionResponse); } + + @Test + public void testNullConnector() { + MLConnectorGetResponse response = MLConnectorGetResponse.builder().build(); + assertNull(response.getMlConnector()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index c2bf0b77b0..17b1673694 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -313,6 +313,36 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException assertEquals("Connector name is null", exception.getMessage()); } + @Test + public void testParseWithTenantId() throws Exception { + String inputWithTenantId = + "{\"name\":\"test_connector_name\",\"credential\":{\"key\":\"test_key_value\"},\"version\":\"1\",\"protocol\":\"http\",\"tenant_id\":\"test_tenant\"}"; + testParseFromJsonString(inputWithTenantId, parsedInput -> { + assertEquals("test_connector_name", parsedInput.getName()); + assertEquals("test_tenant", parsedInput.getTenantId()); + }); + } + + @Test + public void testParseWithUnknownFields() throws Exception { + String inputWithUnknownFields = + "{\"name\":\"test_connector_name\",\"credential\":{\"key\":\"test_key_value\"},\"version\":\"1\",\"protocol\":\"http\",\"unknown_field\":\"unknown_value\"}"; + testParseFromJsonString(inputWithUnknownFields, parsedInput -> { + assertEquals("test_connector_name", parsedInput.getName()); + assertNull(parsedInput.getTenantId()); + }); + } + + @Test + public void testParseWithEmptyActions() throws Exception { + String inputWithEmptyActions = + "{\"name\":\"test_connector_name\",\"credential\":{\"key\":\"test_key_value\"},\"version\":\"1\",\"protocol\":\"http\",\"actions\":[]}"; + testParseFromJsonString(inputWithEmptyActions, parsedInput -> { + assertEquals("test_connector_name", parsedInput.getName()); + assertTrue(parsedInput.getActions().isEmpty()); + }); + } + @Test public void testWriteToVersionCompatibility() throws IOException { MLCreateConnectorInput input = mlCreateConnectorInput; // Assuming mlCreateConnectorInput is already initialized diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MultiTenantPredictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MultiTenantPredictable.java new file mode 100644 index 0000000000..bba5ae904a --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MultiTenantPredictable.java @@ -0,0 +1,18 @@ +package org.opensearch.ml.engine; + +import java.util.Map; + +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.engine.encryptor.Encryptor; + +public interface MultiTenantPredictable extends Predictable { + + /** + * Init model (load model into memory) with ML model content and params. + * @param model ML model + * @param params other parameters + * @param encryptor encryptor + * @param tenantId tenantId + */ + void initModel(MLModel model, Map params, Encryptor encryptor, String tenantId); +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index bb2799941a..e9552cc650 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -5,20 +5,21 @@ package org.opensearch.ml.action.connector; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -34,6 +35,13 @@ import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.DeleteDataObjectRequest; +import org.opensearch.remote.metadata.client.DeleteDataObjectResponse; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; @@ -44,104 +52,175 @@ @Log4j2 public class DeleteConnectorTransportAction extends HandledTransportAction { - Client client; - NamedXContentRegistry xContentRegistry; - - ConnectorAccessControlHelper connectorAccessControlHelper; + private final Client client; + private final SdkClient sdkClient; + private final NamedXContentRegistry xContentRegistry; + private final ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public DeleteConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, NamedXContentRegistry xContentRegistry, - ConnectorAccessControlHelper connectorAccessControlHelper + ConnectorAccessControlHelper connectorAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLConnectorDeleteAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); this.client = client; + this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.fromActionRequest(request); String connectorId = mlConnectorDeleteRequest.getConnectorId(); - DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId).setRefreshPolicy(IMMEDIATE); - connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(x -> { - if (Boolean.TRUE.equals(x)) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); - searchRequest.source(sourceBuilder); - client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> { - SearchHit[] searchHits = searchResponse.getHits().getHits(); - if (searchHits.length == 0) { - deleteConnector(deleteRequest, connectorId, actionListener); - } else { - log - .error( - searchHits.length + " models are still using this connector, please delete or update the models first!" - ); - List modelIds = new ArrayList<>(); - for (SearchHit hit : searchHits) { - modelIds.add(hit.getId()); + String tenantId = mlConnectorDeleteRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + connectorAccessControlHelper + .validateConnectorAccess( + sdkClient, + client, + connectorId, + tenantId, + mlFeatureEnabledSetting, + ActionListener + .wrap( + isAllowed -> handleConnectorAccessValidation(connectorId, tenantId, isAllowed, actionListener), + e -> handleConnectorAccessValidationFailure(connectorId, e, actionListener) + ) + ); + } + + private void handleConnectorAccessValidation( + String connectorId, + String tenantId, + boolean isAllowed, + ActionListener actionListener + ) { + if (isAllowed) { + checkForModelsUsingConnector(connectorId, tenantId, actionListener); + } else { + actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector")); + } + } + + private void handleConnectorAccessValidationFailure(String connectorId, Exception e, ActionListener actionListener) { + log.error("Failed to delete ML connector: {}", connectorId, e); + actionListener.onFailure(e); + } + + private void checkForModelsUsingConnector(String connectorId, String tenantId, ActionListener actionListener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener restoringListener = ActionListener.runBefore(actionListener, context::restore); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + sourceBuilder.query(QueryBuilders.matchQuery(TENANT_ID_FIELD, tenantId)); + } + + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(ML_MODEL_INDEX) + .tenantId(tenantId) + .searchSourceBuilder(sourceBuilder) + .build(); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((sr, st) -> { + if (sr != null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + deleteConnector(connectorId, tenantId, restoringListener); + } else { + handleModelsUsingConnector(searchHits, connectorId, restoringListener); } - actionListener + } catch (Exception e) { + log.error("Failed to parse search response", e); + restoringListener .onFailure( - new OpenSearchStatusException( - searchHits.length - + " models are still using this connector, please delete or update the models first: " - + Arrays.toString(modelIds.toArray(new String[0])), - RestStatus.CONFLICT - ) + new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR) ); } - }, e -> { - if (e instanceof IndexNotFoundException) { - deleteConnector(deleteRequest, connectorId, actionListener); - return; - } - log.error("Failed to delete ML connector: " + connectorId, e); - actionListener.onFailure(e); - }), () -> context.restore())); - } catch (Exception e) { - log.error(e.getMessage(), e); - actionListener.onFailure(e); - } - } else { - actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector")); - } - }, e -> { - log.error("Failed to delete ML connector: " + connectorId, e); + } else { + Exception cause = SdkClientUtils.unwrapAndConvertToException(st); + handleSearchFailure(connectorId, tenantId, cause, restoringListener); + } + }); + } catch (Exception e) { + log.error("Failed to check for models using connector: {}", connectorId, e); actionListener.onFailure(e); - })); + } } - private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener actionListener) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.delete(deleteRequest, ActionListener.runBefore(new ActionListener<>() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { - log.info("Connector id:{} not found", connectorId); - actionListener.onResponse(deleteResponse); - return; - } - log.info("Completed Delete Connector Request, connector id:{} deleted", connectorId); - actionListener.onResponse(deleteResponse); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML connector: " + connectorId, e); - actionListener.onFailure(e); - } - }, () -> context.restore())); + private void handleModelsUsingConnector(SearchHit[] searchHits, String connectorId, ActionListener actionListener) { + log.error("{} models are still using this connector, please delete or update the models first!", searchHits.length); + List modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } + actionListener + .onFailure( + new OpenSearchStatusException( + searchHits.length + + " models are still using this connector, please delete or update the models first: " + + Arrays.toString(modelIds.toArray(new String[0])), + RestStatus.CONFLICT + ) + ); + } + + private void handleSearchFailure(String connectorId, String tenantId, Exception cause, ActionListener actionListener) { + if (cause instanceof IndexNotFoundException) { + deleteConnector(connectorId, tenantId, actionListener); + return; + } + log.error("Failed to search for models using connector: {}", connectorId, cause); + actionListener.onFailure(cause); + } + + private void deleteConnector(String connectorId, String tenantId, ActionListener actionListener) { + DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); + try { + sdkClient + .deleteDataObjectAsync( + DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build(), + client.threadPool().executor(GENERAL_THREAD_POOL) + ) + .whenComplete((response, throwable) -> handleDeleteResponse(response, throwable, connectorId, actionListener)); } catch (Exception e) { - log.error("Failed to delete ML connector: " + connectorId, e); + log.error("Failed to delete ML connector: {}", connectorId, e); actionListener.onFailure(e); } } + + private void handleDeleteResponse( + DeleteDataObjectResponse response, + Throwable throwable, + String connectorId, + ActionListener actionListener + ) { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to delete ML connector: {}", connectorId, cause); + actionListener.onFailure(cause); + } else { + try { + DeleteResponse deleteResponse = DeleteResponse.fromXContent(response.parser()); + log.info("Connector deletion result: {}, connector id: {}", deleteResponse.getResult(), response.id()); + actionListener.onResponse(deleteResponse); + } catch (IOException e) { + actionListener.onFailure(e); + } + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index 12611a942d..8620e79d6e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -5,16 +5,11 @@ package org.opensearch.ml.action.connector; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; -import java.util.Objects; - import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -23,15 +18,16 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -45,22 +41,26 @@ public class GetConnectorTransportAction extends HandledTransportAction { Client client; - NamedXContentRegistry xContentRegistry; + SdkClient sdkClient; ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Inject public GetConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - NamedXContentRegistry xContentRegistry, - ConnectorAccessControlHelper connectorAccessControlHelper + SdkClient sdkClient, + ConnectorAccessControlHelper connectorAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLConnectorGetAction.NAME, transportService, actionFilters, MLConnectorGetRequest::new); this.client = client; - this.xContentRegistry = xContentRegistry; + this.sdkClient = sdkClient; this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -68,64 +68,60 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.debug("Completed Get Connector Request, id:{}", connectorId); - - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector mlConnector = Connector.createConnector(parser); - mlConnector.removeCredential(); - if (!Objects.equals(tenantId, mlConnector.getTenantId())) { - actionListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this connector", - RestStatus.FORBIDDEN - ) - ); - } - if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { - actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this connector", - RestStatus.FORBIDDEN - ) - ); - } - } catch (Exception e) { - log.error("Failed to parse ml connector" + r.getId(), e); - actionListener.onFailure(e); - } - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "Failed to find connector with the provided connector id: " + connectorId, - RestStatus.NOT_FOUND - ) - ); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - log.error("Failed to get connector index", e); - actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); - } else { - log.error("Failed to get ML connector " + connectorId, e); - actionListener.onFailure(e); - } - }), context::restore)); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + context, + getDataObjectRequest, + connectorId, + ActionListener + .wrap( + connector -> handleConnectorAccessValidation(user, tenantId, connector, actionListener), + e -> handleConnectorAccessValidationFailure(connectorId, e, actionListener) + ) + ); } catch (Exception e) { - log.error("Failed to get ML connector " + connectorId, e); + log.error("Failed to get ML connector {}", connectorId, e); actionListener.onFailure(e); } + } + + private void handleConnectorAccessValidation( + User user, + String tenantId, + Connector mlConnector, + ActionListener actionListener + ) { + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlConnector.getTenantId(), actionListener)) { + if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { + actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } + } + } + private void handleConnectorAccessValidationFailure( + String connectorId, + Exception e, + ActionListener actionListener + ) { + log.error("Failed to get ML connector: {}", connectorId, e); + actionListener.onFailure(e); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 92b087f686..1c2aa3a2d5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -6,25 +6,24 @@ package org.opensearch.ml.action.connector; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import java.io.IOException; import java.time.Instant; import java.util.HashSet; import java.util.List; import org.opensearch.action.ActionRequest; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; @@ -41,7 +40,12 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -51,8 +55,11 @@ public class TransportCreateConnectorAction extends HandledTransportAction { private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final SdkClient sdkClient; private final MLEngine mlEngine; private final MLModelManager mlModelManager; + + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; private final ConnectorAccessControlHelper connectorAccessControlHelper; private volatile List trustedConnectorEndpointsRegex; @@ -63,18 +70,22 @@ public TransportCreateConnectorAction( ActionFilters actionFilters, MLIndicesHandler mlIndicesHandler, Client client, + SdkClient sdkClient, MLEngine mlEngine, ConnectorAccessControlHelper connectorAccessControlHelper, Settings settings, ClusterService clusterService, - MLModelManager mlModelManager + MLModelManager mlModelManager, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLCreateConnectorAction.NAME, transportService, actionFilters, MLCreateConnectorRequest::new); this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.sdkClient = sdkClient; this.mlEngine = mlEngine; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -85,6 +96,9 @@ public TransportCreateConnectorAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request); MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, mlCreateConnectorInput.getTenantId(), listener)) { + return; + } if (mlCreateConnectorInput.isDryRun()) { MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); listener.onResponse(response); @@ -115,7 +129,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener indexResponseListener = ActionListener.wrap(r -> { - log.info("Connector saved into index, result:{}, connector id: {}", r.getResult(), r.getId()); - MLCreateConnectorResponse response = new MLCreateConnectorResponse(r.getId()); - listener.onResponse(response); - }, listener::onFailure); - Instant currentTime = Instant.now(); connector.setCreatedTime(currentTime); connector.setLastUpdateTime(currentTime); - - IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX); - indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore)); + sdkClient + .putDataObjectAsync( + PutDataObjectRequest + .builder() + .tenantId(connector.getTenantId()) + .index(ML_CONNECTOR_INDEX) + .dataObject(connector) + .build(), + client.threadPool().executor(GENERAL_THREAD_POOL) + ) + .whenComplete((r, throwable) -> { + context.restore(); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to create ML connector", cause); + listener.onFailure(cause); + } else { + try { + IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); + log + .info( + "Connector creation result: {}, connector id: {}", + indexResponse.getResult(), + indexResponse.getId() + ); + listener.onResponse(new MLCreateConnectorResponse(indexResponse.getId())); + } catch (IOException e) { + listener.onFailure(e); + } + } + }); } catch (Exception e) { log.error("Failed to save ML connector", e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index b1096e7e38..7b7f68e8c0 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -7,22 +7,29 @@ package org.opensearch.ml.helper; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -32,9 +39,15 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import lombok.extern.log4j.Log4j2; @@ -70,13 +83,51 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis getConnector(client, connectorId, ActionListener.wrap(connector -> { boolean hasPermission = hasPermission(user, connector); wrappedListener.onResponse(hasPermission); - }, e -> { wrappedListener.onFailure(e); })); + }, wrappedListener::onFailure)); } catch (Exception e) { log.error("Failed to validate Access for connector:" + connectorId, e); listener.onFailure(e); } } + public void validateConnectorAccess( + SdkClient sdkClient, + Client client, + String connectorId, + String tenantId, + MLFeatureEnabledSetting mlFeatureEnabledSetting, + ActionListener listener + ) { + + User user = RestActionUtils.getUserContext(client); + if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + if (isAdmin(user) || accessControlNotEnabled(user)) { + listener.onResponse(true); + return; + } + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + FetchSourceContext fetchSourceContext = getFetchSourceContext(true); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_CONNECTOR_INDEX) + .tenantId(tenantId) + .id(connectorId) + .fetchSourceContext(fetchSourceContext) + .build(); + getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> { + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, connector.getTenantId(), listener)) { + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + } + }, wrappedListener::onFailure)); + } catch (Exception e) { + log.error("Failed to validate Access for connector:{}", connectorId, e); + listener.onFailure(e); + } + } + public boolean validateConnectorAccess(Client client, Connector connector) { User user = RestActionUtils.getUserContext(client); if (isAdmin(user) || accessControlNotEnabled(user)) { @@ -85,6 +136,8 @@ public boolean validateConnectorAccess(Client client, Connector connector) { return hasPermission(user, connector); } + // TODO will remove this method in favor of other getConnector method. This method is still being used in update model/update connect. + // I'll remove this method when I'll refactor update methods. public void getConnector(Client client, String connectorId, ActionListener listener) { GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); client.get(getRequest, ActionListener.wrap(r -> { @@ -109,6 +162,71 @@ public void getConnector(Client client, String connectorId, ActionListener listener + ) { + + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + context.restore(); + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (cause instanceof IndexNotFoundException) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML connector " + connectorId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + listener.onResponse(mlConnector); + } catch (Exception e) { + log.error("Failed to parse ml connector {}", r.id(), e); + listener.onFailure(e); + } + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: " + connectorId, + RestStatus.NOT_FOUND + ) + ); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + }); + + } + public boolean skipConnectorAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9dc60998db..053f5e963a 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -732,7 +732,8 @@ public Collection createComponents( clusterManagerEventListener, mlCircuitBreakerService, mlModelAutoRedeployer, - cmHandler + cmHandler, + sdkClient ); } @@ -776,7 +777,7 @@ public List getRestHandlers( RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); - RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); + RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java index a1e05ce7d6..177de35152 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -71,6 +72,8 @@ MLCreateConnectorRequest getRequest(RestRequest request) throws IOException { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.parse(parser); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + mlCreateConnectorInput.setTenantId(tenantId); return new MLCreateConnectorRequest(mlCreateConnectorInput); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java index 532cd26123..d72148bb66 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -15,6 +16,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,7 +29,11 @@ public class RestMLDeleteConnectorAction extends BaseRestHandler { private static final String ML_DELETE_CONNECTOR_ACTION = "ml_delete_connector_action"; - public void RestMLDeleteConnectorAction() {} + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLDeleteConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -45,8 +51,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String connectorId = request.param(PARAMETER_CONNECTOR_ID); - - MLConnectorDeleteRequest mlConnectorDeleteRequest = new MLConnectorDeleteRequest(connectorId); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + MLConnectorDeleteRequest mlConnectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId); return channel -> client.execute(MLConnectorDeleteAction.INSTANCE, mlConnectorDeleteRequest, new RestToXContentListener<>(channel)); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 66b0f163f2..91476cf5b2 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -308,13 +308,13 @@ private MLCommonsSettings() {} /** This setting sets the remote metadata endpoint */ public static final Setting REMOTE_METADATA_ENDPOINT = Setting - .simpleString("plugins.flow_framework." + REMOTE_METADATA_ENDPOINT_KEY, Setting.Property.NodeScope, Setting.Property.Final); + .simpleString("plugins.ml_commons." + REMOTE_METADATA_ENDPOINT_KEY, Setting.Property.NodeScope, Setting.Property.Final); /** This setting sets the remote metadata region */ public static final Setting REMOTE_METADATA_REGION = Setting - .simpleString("plugins.flow_framework." + REMOTE_METADATA_REGION_KEY, Setting.Property.NodeScope, Setting.Property.Final); + .simpleString("plugins.ml_commons." + REMOTE_METADATA_REGION_KEY, Setting.Property.NodeScope, Setting.Property.Final); /** This setting sets the remote metadata service name */ public static final Setting REMOTE_METADATA_SERVICE_NAME = Setting - .simpleString("plugins.flow_framework." + REMOTE_METADATA_SERVICE_NAME_KEY, Setting.Property.NodeScope, Setting.Property.Final); + .simpleString("plugins.ml_commons." + REMOTE_METADATA_SERVICE_NAME_KEY, Setting.Property.NodeScope, Setting.Property.Final); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index 8a97c87e61..7884c4cbaf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -6,14 +6,24 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.action.DocWriteResponse.Result.DELETED; +import static org.opensearch.action.DocWriteResponse.Result.NOT_FOUND; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -22,19 +32,25 @@ import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -44,21 +60,42 @@ import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class DeleteConnectorTransportActionTests extends OpenSearchTestCase { + + private static final String CONNECTOR_ID = "connector_id"; + + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportCreateConnectorActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock ThreadPool threadPool; @Mock Client client; + SdkClient sdkClient; + @Mock TransportService transportService; @@ -85,99 +122,146 @@ public class DeleteConnectorTransportActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId("connector_id").build(); - + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(CONNECTOR_ID).build(); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); Settings settings = Settings.builder().build(); deleteConnectorTransportAction = spy( - new DeleteConnectorTransportAction(transportService, actionFilters, client, xContentRegistry, connectorAccessControlHelper) + new DeleteConnectorTransportAction( + transportService, + actionFilters, + client, + sdkClient, + xContentRegistry, + connectorAccessControlHelper, + mlFeatureEnabledSetting + ) ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(true); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); } - public void testDeleteConnector_Success() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + public void testDeleteConnector_Success() throws InterruptedException { + DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(deleteResponse); + when(client.delete(any(DeleteRequest.class))).thenReturn(future); - SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(searchResponse); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).search(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + SearchResponse searchResponse = getEmptySearchResponse(); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(CONNECTOR_ID, captor.getValue().getId()); + assertEquals(DELETED, captor.getValue().getResult()); } - public void testDeleteConnector_ModelIndexNotFoundSuccess() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); + public void testDeleteConnector_ModelIndexNotFoundSuccess() throws InterruptedException { + DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(deleteResponse); + when(client.delete(any(DeleteRequest.class))).thenReturn(future); - SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IndexNotFoundException("ml_model index not found!")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).search(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onFailure(new IndexNotFoundException("ml_model index not found!")); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(CONNECTOR_ID, captor.getValue().getId()); + assertEquals(DELETED, captor.getValue().getResult()); } - public void testDeleteConnector_ConnectorNotFound() throws IOException { - when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + public void testDeleteConnector_ConnectorNotFound() throws InterruptedException { + DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, false); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(deleteResponse); + when(client.delete(any(DeleteRequest.class))).thenReturn(future); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).delete(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); SearchResponse searchResponse = getEmptySearchResponse(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(CONNECTOR_ID, captor.getValue().getId()); + assertEquals(NOT_FOUND, captor.getValue().getResult()); } - public void testDeleteConnector_BlockedByModel() throws IOException { + public void testDeleteConnector_BlockedByModel() throws IOException, InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).delete(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); SearchResponse searchResponse = getNonEmptySearchResponse(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -188,10 +272,10 @@ public void testDeleteConnector_BlockedByModel() throws IOException { public void test_UserHasNoAccessException() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -199,67 +283,72 @@ public void test_UserHasNoAccessException() throws IOException { assertEquals("You are not allowed to delete this connector", argumentCaptor.getValue().getMessage()); } - public void testDeleteConnector_SearchFailure() throws IOException { + public void testDeleteConnector_SearchFailure() throws InterruptedException { doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new RuntimeException("Search Failed!")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).search(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new ResourceNotFoundException("errorMessage")); - return null; - }).when(client).delete(any(), any()); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onFailure(new RuntimeException("Search Failed!")); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Search Failed!", argumentCaptor.getValue().getMessage()); } - public void testDeleteConnector_SearchException() throws IOException { + public void testDeleteConnector_SearchException() { when(client.threadPool()).thenThrow(new RuntimeException("Thread Context Error!")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Thread Context Error!", argumentCaptor.getValue().getMessage()); } - public void testDeleteConnector_ResourceNotFoundException() throws IOException { - SearchResponse searchResponse = getEmptySearchResponse(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + public void testDeleteConnector_ResourceNotFoundException() throws InterruptedException { + when(client.delete(any(DeleteRequest.class))).thenThrow(new ResourceNotFoundException("errorMessage")); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new ResourceNotFoundException("errorMessage")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); return null; - }).when(client).delete(any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + + SearchResponse searchResponse = getEmptySearchResponse(); + PlainActionFuture searchFuture = PlainActionFuture.newFuture(); + searchFuture.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } - public void test_ValidationFailedException() throws IOException { - GetResponse getResponse = prepareMLConnector(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).search(any(), any()); - + public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -267,6 +356,23 @@ public void test_ValidationFailedException() throws IOException { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + public void testDeleteConnector_MultiTenancyEnabled_NoTenantId() throws InterruptedException { + // Enable multi-tenancy + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + // Create a request without a tenant ID + MLConnectorDeleteRequest requestWithoutTenant = MLConnectorDeleteRequest.builder().connectorId(CONNECTOR_ID).build(); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + deleteConnectorTransportAction.doExecute(null, requestWithoutTenant, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage()); + } + public GetResponse prepareMLConnector() throws IOException { HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").build(); XContentBuilder content = connector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index c2cbf81cf1..9626b75959 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -6,49 +6,84 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.ExpectedException; +import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class GetConnectorTransportActionTests extends OpenSearchTestCase { + private static final String CONNECTOR_ID = "connector_id"; + + private static final String TENANT_ID = "_tenant_id"; + + private static final TestThreadPool testThreadPool = new TestThreadPool( + GetConnectorTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock ThreadPool threadPool; @Mock Client client; - @Mock - NamedXContentRegistry xContentRegistry; + SdkClient sdkClient; @Mock TransportService transportService; @@ -59,114 +94,158 @@ public class GetConnectorTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; + @Mock + GetResponse getResponse; + @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); GetConnectorTransportAction getConnectorTransportAction; MLConnectorGetRequest mlConnectorGetRequest; ThreadContext threadContext; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Captor + private ArgumentCaptor getDataObjectRequestArgumentCaptor; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId("connector_id").build(); + Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(TENANT_ID).build(); + when(getResponse.getId()).thenReturn(CONNECTOR_ID); + when(getResponse.getSourceAsString()).thenReturn("{}"); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); getConnectorTransportAction = spy( - new GetConnectorTransportAction(transportService, actionFilters, client, xContentRegistry, connectorAccessControlHelper) + new GetConnectorTransportAction( + transportService, + actionFilters, + client, + sdkClient, + connectorAccessControlHelper, + mlFeatureEnabledSetting + ) ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); - threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); } - public void testGetConnector_UserHasNodeAccess() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(false); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); - - GetResponse getResponse = prepareConnector(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } - public void testGetConnector_ValidateAccessFailed() throws IOException { + @Test + public void testGetConnector_UserHasNoAccess() throws IOException, InterruptedException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId("tenantId").build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new Exception("Failed to validate access")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); - GetResponse getResponse = prepareConnector(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); + GetResponse getResponse = prepareConnector(null); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); } - public void testGetConnector_NullResponse() { + @Test + public void testGetConnector_NullResponse() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(null); + ActionListener listener = invocation.getArgument(5); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: connector_id", + RestStatus.NOT_FOUND + ) + ); return null; - }).when(client).get(any(), any()); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find connector with the provided connector id: connector_id", argumentCaptor.getValue().getMessage()); } - public void testGetConnector_IndexNotFoundException() { + public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); + String tenantId = "test_tenant"; + mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(tenantId).build(); + + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(tenantId).build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IndexNotFoundException("Fail to find model")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); return null; - }).when(client).get(any(), any()); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), + getDataObjectRequestArgumentCaptor.capture(), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + Assert.assertEquals(tenantId, getDataObjectRequestArgumentCaptor.getValue().tenantId()); + Assert.assertEquals(CONNECTOR_ID, getDataObjectRequestArgumentCaptor.getValue().id()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorGetResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(tenantId, argumentCaptor.getValue().getMlConnector().getTenantId()); } - public void testGetConnector_RuntimeException() { + @Test + public void testGetConnector_MultiTenancyEnabled_ForbiddenAccess() throws IOException, InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); + String tenantId = "test_tenant"; + mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(tenantId).build(); + + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId("tenantId").build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("errorMessage")); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); return null; - }).when(client).get(any(), any()); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareConnector() throws IOException { - HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").build(); + public GetResponse prepareConnector(String tenantId) throws IOException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(tenantId).build(); XContentBuilder content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 1a0cd7716f..9709cbb0ce 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -5,35 +5,51 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; @@ -44,8 +60,14 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -54,12 +76,27 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { + private static final String CONNECTOR_ID = "connector_id"; + private static final String TENANT_ID = "_tenant_id"; + + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportCreateConnectorActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + private TransportCreateConnectorAction action; @Mock private MLIndicesHandler mlIndicesHandler; @Mock private Client client; + private SdkClient sdkClient; @Mock private MLEngine mlEngine; @Mock @@ -82,6 +119,11 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; + IndexResponse indexResponse; + + @Mock + NamedXContentRegistry xContentRegistry; + @Mock private ThreadPool threadPool; @@ -92,32 +134,47 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private Settings settings; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Captor + private ArgumentCaptor putDataObjectRequestArgumentCaptor; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings .builder() .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + indexResponse = new IndexResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); + ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + action = new TransportCreateConnectorAction( transportService, actionFilters, mlIndicesHandler, client, + sdkClient, mlEngine, connectorAccessControlHelper, settings, clusterService, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); @@ -125,6 +182,7 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); List actions = new ArrayList<>(); actions @@ -151,7 +209,12 @@ public void setup() { when(request.getMlCreateConnectorInput()).thenReturn(input); } - public void test_execute_connectorAccessControl_notEnabled_success() { + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + public void test_execute_connectorAccessControl_notEnabled_success() throws InterruptedException { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); input.setAddAllBackendRoles(null); input.setBackendRoles(null); @@ -162,19 +225,23 @@ public void test_execute_connectorAccessControl_notEnabled_success() { return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - action.doExecute(task, request, actionListener); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(indexResponse); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + action.doExecute(task, request, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); } - public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_exception() { + public void test_execute_connector_registration_multi_tenancy_fail() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); + input.setAddAllBackendRoles(null); input.setBackendRoles(null); - input.setAddAllBackendRoles(true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -182,21 +249,52 @@ public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_ex return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(indexResponse); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + action.doExecute(task, request, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to access this resource", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_exception() throws InterruptedException { + when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); + input.setBackendRoles(null); + input.setAddAllBackendRoles(true); + doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - action.doExecute(task, request, actionListener); + }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(mock(IndexResponse.class)); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + action.doExecute(task, request, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster.", - argumentCaptor.getValue().getMessage() + "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster.", + argumentCaptor.getValue().getMessage() ); } - public void test_execute_connectorAccessControlEnabled_success() { + public void test_execute_connectorAccessControlEnabled_success() throws InterruptedException { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); input.setAddAllBackendRoles(false); input.setBackendRoles(ImmutableList.of("role1", "role2")); @@ -207,16 +305,20 @@ public void test_execute_connectorAccessControlEnabled_success() { return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - action.doExecute(task, request, actionListener); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(indexResponse); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + request.getMlCreateConnectorInput().setTenantId(TENANT_ID); + action.doExecute(task, request, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); } - public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() { + public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() throws InterruptedException { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); input.setAddAllBackendRoles(null); input.setBackendRoles(null); @@ -227,12 +329,15 @@ public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_def return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - action.doExecute(task, request, actionListener); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(indexResponse); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + action.doExecute(task, request, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); } @@ -248,16 +353,11 @@ public void test_execute_connectorAccessControlEnabled_adminSpecifyAllBackendRol return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "Admin can't add all backend roles", argumentCaptor.getValue().getMessage() + "Admin can't add all backend roles", argumentCaptor.getValue().getMessage() ); } @@ -273,17 +373,12 @@ public void test_execute_connectorAccessControlEnabled_specifyBackendRolesForPub return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "You can specify backend roles only for a connector with the restricted access mode.", - argumentCaptor.getValue().getMessage() + "You can specify backend roles only for a connector with the restricted access mode.", + argumentCaptor.getValue().getMessage() ); } @@ -306,21 +401,18 @@ public void test_execute_connectorAccessControlEnabled_userNoBackendRoles_except return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters, mlIndicesHandler, client, + sdkClient, mlEngine, connectorAccessControlHelper, settings, clusterService, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -347,21 +439,18 @@ public void test_execute_connectorAccessControlEnabled_parameterConflict_excepti return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters, mlIndicesHandler, client, + sdkClient, mlEngine, connectorAccessControlHelper, settings, clusterService, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -391,21 +480,18 @@ public void test_execute_connectorAccessControlEnabled_specifyNotBelongedRole_ex return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters, mlIndicesHandler, client, + sdkClient, mlEngine, connectorAccessControlHelper, settings, clusterService, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -424,12 +510,6 @@ public void test_execute_dryRun_connector_creation() { return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(IndexResponse.class)); - return null; - }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - MLCreateConnectorInput mlCreateConnectorInput = mock(MLCreateConnectorInput.class); when(mlCreateConnectorInput.getName()).thenReturn(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); when(mlCreateConnectorInput.isDryRun()).thenReturn(true); @@ -469,11 +549,13 @@ public void test_execute_URL_notMatchingExpression_exception() { actionFilters, mlIndicesHandler, client, + sdkClient, mlEngine, connectorAccessControlHelper, settings, clusterService, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ); action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 30c9f6191c..06f86dadc9 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -8,43 +8,67 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; @@ -60,6 +84,9 @@ public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { @Mock private ActionListener actionListener; + @Mock + private ActionListener getConnectorActionListener; + @Mock private ThreadPool threadPool; @@ -71,14 +98,35 @@ public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { private User user; + SdkClient sdkClient; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + + private static TestThreadPool testThreadPool = new TestThreadPool( + ConnectorAccessControlHelperTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Before public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + connectorAccessControlHelper = spy(new ConnectorAccessControlHelper(clusterService, settings)); user = User.parse("mockUser|role-1,role-2|null"); getResponse = createGetResponse(null); @@ -90,14 +138,22 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } + @Test public void test_hasPermission_user_null_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); boolean hasPermission = connectorAccessControlHelper.hasPermission(null, httpConnector); assertTrue(hasPermission); } + @Test public void test_hasPermission_connectorAccessControl_not_enabled_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); @@ -108,6 +164,7 @@ public void test_hasPermission_connectorAccessControl_not_enabled_return_true() assertTrue(hasPermission); } + @Test public void test_hasPermission_connectorOwner_is_null_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getOwner()).thenReturn(null); @@ -115,12 +172,14 @@ public void test_hasPermission_connectorOwner_is_null_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_user_is_admin_return_true() { User user = User.parse("admin|role-1|all_access"); boolean hasPermission = connectorAccessControlHelper.hasPermission(user, mock(HttpConnector.class)); assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPublic_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PUBLIC); @@ -128,6 +187,7 @@ public void test_hasPermission_connector_isPublic_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); @@ -136,6 +196,7 @@ public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); @@ -145,6 +206,7 @@ public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() assertFalse(hasPermission); } + @Test public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); @@ -153,6 +215,7 @@ public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_ assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_return_false() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); @@ -162,7 +225,8 @@ public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_retu assertFalse(hasPermission); } - public void test_validateConnectorAccess_user_isAdmin_return_true() { + // todo: will remove this later + public void test_validateConnectorAccess_user_isAdmin_return_true_old() { String userString = "admin|role-1|all_access"; Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); ThreadContext threadContext = new ThreadContext(settings); @@ -174,7 +238,21 @@ public void test_validateConnectorAccess_user_isAdmin_return_true() { verify(actionListener).onResponse(true); } - public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false() { + @Test + public void test_validateConnectorAccess_user_isAdmin_return_true() { + String userString = "admin|role-1|all_access"; + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); + + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener).onResponse(true); + } + + // todo will remove later. + public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false_old() { GetResponse getResponse = createGetResponse(ImmutableList.of("role-3")); Client client = mock(Client.class); doAnswer(invocation -> { @@ -190,12 +268,67 @@ public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return verify(actionListener).onResponse(false); } + @Test + public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false() throws Exception { + // Mock the client thread pool + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + // Set up user context + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + // Create HttpConnector + HttpConnector httpConnector = HttpConnector.builder() + .name("testConnector") + .protocol(ConnectorProtocols.HTTP) + .owner(user) + .description("This is test connector") + .backendRoles(Collections.singletonList("role-3")) + .accessMode(AccessMode.RESTRICTED) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + + // Execute the validation + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + + // Verify the action listener was called with false + verify(actionListener).onResponse(false); + } + + @Test public void test_validateConnectorAccess_user_isNotAdmin_hasBackendRole_return_true() { + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener).onResponse(true); + } + + // todo will remove later + public void test_validateConnectorAccess_user_isNotAdmin_hasBackendRole_return_true_old() { connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); verify(actionListener).onResponse(true); } + @Test public void test_validateConnectorAccess_connectorNotFound_return_false() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + // connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); + } + + // todo will remove later + public void test_validateConnectorAccess_connectorNotFound_return_false_old() { Client client = mock(Client.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -210,7 +343,24 @@ public void test_validateConnectorAccess_connectorNotFound_return_false() { verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } + @Test public void test_validateConnectorAccess_searchConnectorException_return_false() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onFailure(new RuntimeException("Failed to find connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + // connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener, times(1)).onFailure(any(RuntimeException.class)); + } + + // todo will remove later + public void test_validateConnectorAccess_searchConnectorException_return_false_old() { Client client = mock(Client.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -225,11 +375,13 @@ public void test_validateConnectorAccess_searchConnectorException_return_false() verify(actionListener).onFailure(any(OpenSearchStatusException.class)); } + @Test public void test_skipConnectorAccessControl_userIsNull_return_true() { boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(null); assertTrue(skip); } + @Test public void test_skipConnectorAccessControl_connectorAccessControl_notEnabled_return_true() { Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); @@ -239,12 +391,14 @@ public void test_skipConnectorAccessControl_connectorAccessControl_notEnabled_re assertTrue(skip); } + @Test public void test_skipConnectorAccessControl_userIsAdmin_return_true() { User user = User.parse("admin|role-1|all_access"); boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(user); assertTrue(skip); } + @Test public void test_accessControlNotEnabled_connectorAccessControl_notEnabled_return_true() { Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); @@ -254,17 +408,20 @@ public void test_accessControlNotEnabled_connectorAccessControl_notEnabled_retur assertTrue(skip); } + @Test public void test_accessControlNotEnabled_userIsNull_return_true() { boolean notEnabled = connectorAccessControlHelper.accessControlNotEnabled(null); assertTrue(notEnabled); } + @Test public void test_addUserBackendRolesFilter_nullQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); SearchSourceBuilder result = connectorAccessControlHelper.addUserBackendRolesFilter(user, searchSourceBuilder); assertNotNull(result); } + @Test public void test_addUserBackendRolesFilter_boolQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new BoolQueryBuilder()); @@ -272,6 +429,7 @@ public void test_addUserBackendRolesFilter_boolQuery() { assertEquals("bool", result.query().getName()); } + @Test public void test_addUserBackendRolesFilter_nonBoolQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new MatchAllQueryBuilder()); @@ -279,6 +437,86 @@ public void test_addUserBackendRolesFilter_nonBoolQuery() { assertEquals("bool", result.query().getName()); } + @Test + public void testGetConnectorHappyCase() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); + GetResponse getResponse = prepareConnector(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); + verify(client, times(1)).get(requestCaptor.capture()); + assertEquals(CommonValue.ML_CONNECTOR_INDEX, requestCaptor.getValue().index()); + } + + @Test + public void testGetConnectorException() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("Failed to get connector")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetConnectorIndexNotFound() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new IndexNotFoundException("Index not found")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); + assertEquals(RestStatus.NOT_FOUND, argumentCaptor.getValue().status()); + } + private GetResponse createGetResponse(List backendRoles) { HttpConnector httpConnector = HttpConnector .builder() @@ -289,7 +527,7 @@ private GetResponse createGetResponse(List backendRoles) { .backendRoles(Optional.ofNullable(backendRoles).orElse(ImmutableList.of("role-1"))) .accessMode(AccessMode.RESTRICTED) .build(); - XContentBuilder content = null; + XContentBuilder content; try { content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); } catch (IOException e) { @@ -299,4 +537,13 @@ private GetResponse createGetResponse(List backendRoles) { GetResult getResult = new GetResult(CommonValue.ML_MODEL_GROUP_INDEX, "111", 111l, 111l, 111l, true, bytesReference, null, null); return new GetResponse(getResult); } + + public GetResponse prepareConnector() throws IOException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").build(); + XContentBuilder content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java index 07f823f905..552bc28232 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java @@ -104,7 +104,7 @@ public void testRoutes() { } public void testGetRequest() throws IOException { - RestRequest request = getCreateConnectorRestRequest(); + RestRequest request = getCreateConnectorRestRequest(null); MLCreateConnectorRequest mlCreateConnectorRequest = restMLCreateConnectorAction.getRequest(request); MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); @@ -112,7 +112,7 @@ public void testGetRequest() throws IOException { } public void testPrepareRequest() throws Exception { - RestRequest request = getCreateConnectorRestRequest(); + RestRequest request = getCreateConnectorRestRequest(null); restMLCreateConnectorAction.handleRequest(request, channel, client); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateConnectorRequest.class); @@ -135,7 +135,17 @@ public void testPrepareRequestFeatureDisabled() throws Exception { thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); - RestRequest request = getCreateConnectorRestRequest(); + RestRequest request = getCreateConnectorRestRequest(null); restMLCreateConnectorAction.handleRequest(request, channel, client); } + + public void testGetRequest_MultiTenancyEnabled() throws IOException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + RestRequest request = getCreateConnectorRestRequest("tenantId"); + MLCreateConnectorRequest mlCreateConnectorRequest = restMLCreateConnectorAction.getRequest(request); + + MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); + verifyParsedCreateConnectorInput(mlCreateConnectorInput); + assertEquals("tenantId", mlCreateConnectorInput.getTenantId()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java index bce92d9b69..4adb29c1b7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java @@ -7,27 +7,30 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.*; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.input.Constants; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -37,6 +40,8 @@ import org.opensearch.threadpool.ThreadPool; public class RestMLDeleteConnectorActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); private RestMLDeleteConnectorAction restMLDeleteConnectorAction; @@ -46,9 +51,14 @@ public class RestMLDeleteConnectorActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -68,7 +78,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLDeleteConnectorAction mlDeleteConnectorAction = new RestMLDeleteConnectorAction(); + RestMLDeleteConnectorAction mlDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); assertNotNull(mlDeleteConnectorAction); } @@ -88,7 +98,7 @@ public void testRoutes() { } public void test_PrepareRequest() throws Exception { - RestRequest request = getRestRequest(); + RestRequest request = getRestRequest("connector_id", null); restMLDeleteConnectorAction.handleRequest(request, channel, client); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorDeleteRequest.class); @@ -97,10 +107,27 @@ public void test_PrepareRequest() throws Exception { assertEquals(connectorId, "connector_id"); } - private RestRequest getRestRequest() { + public void testPrepareRequest_MultiTenancyEnabled() throws Exception { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + RestRequest request = getRestRequest("connector_id", "_tenant_id"); + restMLDeleteConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorDeleteRequest.class); + verify(client, times(1)).execute(eq(MLConnectorDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + MLConnectorDeleteRequest mlConnectorDeleteRequest = argumentCaptor.getValue(); + assertEquals("connector_id", mlConnectorDeleteRequest.getConnectorId()); + assertEquals("_tenant_id", mlConnectorDeleteRequest.getTenantId()); + } + + private RestRequest getRestRequest(String connectorId, String tenantId) { Map params = new HashMap<>(); - params.put(PARAMETER_CONNECTOR_ID, "connector_id"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); - return request; + params.put(PARAMETER_CONNECTOR_ID, connectorId); + + Map> headers = new HashMap<>(); + if (tenantId != null) { + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + } + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).withHeaders(headers).build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 895345825f..0796fff279 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -76,6 +76,7 @@ import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.Constants; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; @@ -245,7 +246,12 @@ public static RestRequest getBatchRestRequest_WrongActionType() { return request; } - public static RestRequest getCreateConnectorRestRequest() { + public static RestRequest getCreateConnectorRestRequest(String tenantId) { + Map> headers = new HashMap<>(); + if (tenantId != null) { + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + } + final String requestContent = "{\n" + " \"name\": \"OpenAI Connector\",\n" + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" @@ -276,6 +282,7 @@ public static RestRequest getCreateConnectorRestRequest() { + " \"access_mode\": \"public\"\n" + "}"; RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withHeaders(headers) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); return request;