From 4f6a40659b1460d03fbd5b1a38b4c6cf85f10b15 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 10 Jan 2025 18:36:14 -0500 Subject: [PATCH] [Manual Backport to 2.x] Primary setup for Multi-tenancy (#3307) (#3366) * Primary setup for Multi-tenancy (#3307) * multi-tenancy primary setup Signed-off-by: Dhrubo Saha * addressed comments Signed-off-by: Dhrubo Saha * addressed comments + fixed dependency issue Signed-off-by: Dhrubo Saha * adding more log to debug the testVisualizationFound issue Signed-off-by: Dhrubo Saha * changing back Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha * apply spotless Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha --- .github/workflows/CI-workflow.yml | 1 - .../org/opensearch/ml/common/CommonValue.java | 4 + .../common/connector/AbstractConnector.java | 2 + .../ml/common/connector/Connector.java | 4 + .../opensearch/ml/common/input/Constants.java | 1 + .../settings/SettingsChangeListener.java | 22 +++ .../connector/MLConnectorGetRequest.java | 14 +- .../resources/index-mappings/ml-agent.json | 5 +- .../resources/index-mappings/ml-config.json | 5 +- .../index-mappings/ml-connector.json | 5 +- .../index-mappings/ml-model-group.json | 5 +- .../resources/index-mappings/ml-model.json | 5 +- .../resources/index-mappings/ml-task.json | 5 +- gradle/wrapper/gradle-wrapper.properties | 1 + ml-algorithms/build.gradle | 2 + .../engine/indices/MLIndicesHandlerTest.java | 4 +- plugin/build.gradle | 13 ++ .../GetConnectorTransportAction.java | 12 ++ .../ml/plugin/MachineLearningPlugin.java | 42 +++- .../ml/rest/RestMLGetConnectorAction.java | 20 +- .../ml/settings/MLCommonsSettings.java | 54 +++++ .../ml/settings/MLFeatureEnabledSetting.java | 31 +++ .../ml/utils/TenantAwareHelper.java | 99 ++++++++++ .../action/prediction/PredictionITTests.java | 20 +- .../rest/RestMLGetConnectorActionTests.java | 28 ++- .../MLFeatureEnabledSettingTests.java | 76 ++++++++ .../ml/utils/TenantAwareHelperTests.java | 184 ++++++++++++++++++ 27 files changed, 636 insertions(+), 28 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java create mode 100644 plugin/src/main/java/org/opensearch/ml/utils/TenantAwareHelper.java create mode 100644 plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/utils/TenantAwareHelperTests.java diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 1ab56fdb95..6bdc1250be 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -52,7 +52,6 @@ jobs: with: role-to-assume: ${{ secrets.ML_ROLE }} aws-region: us-west-2 - - name: Checkout MLCommons uses: actions/checkout@v4 with: diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index e06e552536..ef6c067b05 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -21,6 +21,9 @@ public class CommonValue { public static final String UNDEPLOYED = "undeployed"; public static final String NOT_FOUND = "not_found"; + /** The field name containing the tenant id */ + public static final String TENANT_ID_FIELD = "tenant_id"; + public static final String MASTER_KEY = "master_key"; public static final String CREATE_TIME_FIELD = "create_time"; public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; @@ -63,4 +66,5 @@ public class CommonValue { public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); public static final Version VERSION_2_18_0 = Version.fromString("2.18.0"); + public static final Version VERSION_2_19_0 = Version.fromString("2.19.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index d8adc7ac54..7cf45d8d26 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -71,6 +71,8 @@ public abstract class AbstractConnector implements Connector { protected Instant lastUpdateTime; @Setter protected ConnectorClientConfig connectorClientConfig; + @Setter + protected String tenantId; protected Map createDecryptedHeaders(Map headers) { if (headers == null) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 86068ad0f9..1bdb6747d1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -43,6 +43,10 @@ public interface Connector extends ToXContentObject, Writeable { String getName(); + String getTenantId(); + + void setTenantId(String tenantId); + String getProtocol(); void setCreatedTime(Instant createdTime); diff --git a/common/src/main/java/org/opensearch/ml/common/input/Constants.java b/common/src/main/java/org/opensearch/ml/common/input/Constants.java index 256f73f7c1..16d9c05438 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/Constants.java +++ b/common/src/main/java/org/opensearch/ml/common/input/Constants.java @@ -36,4 +36,5 @@ public class Constants { public static final String AD_TRAINING_DATA_SIZE = "trainingDataSize"; public static final String AD_ANOMALY_SCORE_THRESHOLD = "anomalyScoreThreshold"; public static final String AD_DATE_FORMAT = "dateFormat"; + public static final String TENANT_ID_HEADER = "x-tenant-id"; } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java b/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java new file mode 100644 index 0000000000..946e88239a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.settings; + +/** + * Interface for handling settings changes in the OpenSearch ML plugin. + */ +public interface SettingsChangeListener { + /** + * Callback method that gets triggered when the multi-tenancy setting changes. + * + * @param isEnabled A boolean value indicating the new state of the multi-tenancy setting: + *
    + *
  • true if multi-tenancy is enabled
  • + *
  • false if multi-tenancy is disabled
  • + *
+ */ + void onMultiTenancyEnabledChanged(boolean isEnabled); +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java index 53c6c9c497..c8a89ea4a5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.connector; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -26,24 +28,34 @@ public class MLConnectorGetRequest extends ActionRequest { String connectorId; + String tenantId; boolean returnContent; @Builder - public MLConnectorGetRequest(String connectorId, boolean returnContent) { + public MLConnectorGetRequest(String connectorId, String tenantId, boolean returnContent) { this.connectorId = connectorId; + this.tenantId = tenantId; this.returnContent = returnContent; } public MLConnectorGetRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.connectorId = in.readString(); + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = in.readOptionalString(); + } this.returnContent = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeString(this.connectorId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(this.tenantId); + } out.writeBoolean(returnContent); } diff --git a/common/src/main/resources/index-mappings/ml-agent.json b/common/src/main/resources/index-mappings/ml-agent.json index 2bcee6bc3b..9d4deeca51 100644 --- a/common/src/main/resources/index-mappings/ml-agent.json +++ b/common/src/main/resources/index-mappings/ml-agent.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 2 + "schema_version": 3 }, "properties": { "name": { @@ -33,6 +33,9 @@ "is_hidden": { "type": "boolean" }, + "tenant_id": { + "type": "keyword" + }, "created_time": { "type": "date", "format": "strict_date_time||epoch_millis" diff --git a/common/src/main/resources/index-mappings/ml-config.json b/common/src/main/resources/index-mappings/ml-config.json index 6d36d8efb7..8882814df0 100644 --- a/common/src/main/resources/index-mappings/ml-config.json +++ b/common/src/main/resources/index-mappings/ml-config.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 4 + "schema_version": 5 }, "properties": { "master_key": { @@ -9,6 +9,9 @@ "config_type": { "type": "keyword" }, + "tenant_id": { + "type": "keyword" + }, "ml_configuration": { "type": "flat_object" }, diff --git a/common/src/main/resources/index-mappings/ml-connector.json b/common/src/main/resources/index-mappings/ml-connector.json index 4be168c4b9..2eecf36f5c 100644 --- a/common/src/main/resources/index-mappings/ml-connector.json +++ b/common/src/main/resources/index-mappings/ml-connector.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "name": { @@ -30,6 +30,9 @@ "client_config": { "type": "flat_object" }, + "tenant_id": { + "type": "keyword" + }, "actions": { "type": "flat_object" }, diff --git a/common/src/main/resources/index-mappings/ml-model-group.json b/common/src/main/resources/index-mappings/ml-model-group.json index 7e2437e534..bd05dd55f3 100644 --- a/common/src/main/resources/index-mappings/ml-model-group.json +++ b/common/src/main/resources/index-mappings/ml-model-group.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 2 + "schema_version": 3 }, "properties": { "name": { @@ -21,6 +21,9 @@ "model_group_id": { "type": "keyword" }, + "tenant_id": { + "type": "keyword" + }, "backend_roles": { "type": "text", "fields": { diff --git a/common/src/main/resources/index-mappings/ml-model.json b/common/src/main/resources/index-mappings/ml-model.json index b996e463cd..8b92b71249 100644 --- a/common/src/main/resources/index-mappings/ml-model.json +++ b/common/src/main/resources/index-mappings/ml-model.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 11 + "schema_version": 12 }, "properties": { "algorithm": { @@ -63,6 +63,9 @@ "is_hidden": { "type": "boolean" }, + "tenant_id": { + "type": "keyword" + }, "model_config": { "properties": { "model_type": { diff --git a/common/src/main/resources/index-mappings/ml-task.json b/common/src/main/resources/index-mappings/ml-task.json index ad428724bf..ca71b22906 100644 --- a/common/src/main/resources/index-mappings/ml-task.json +++ b/common/src/main/resources/index-mappings/ml-task.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "model_id": { @@ -38,6 +38,9 @@ "error": { "type": "text" }, + "tenant_id": { + "type": "keyword" + }, "is_async": { "type": "boolean" }, diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index cbaae54fa2..142d400f41 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -7,6 +7,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists distributionUrl=https\://services.gradle.org/distributions/gradle-8.11.1-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists distributionSha256Sum=f397b287023acdba1e9f6fc5ea72d22dd63669d59ed4a289a29b1a76eee151c6 diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index d208f9a78e..1d1b16d6ca 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -84,8 +84,10 @@ lombok { configurations.all { resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5' resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' + resolutionStrategy.force 'software.amazon.awssdk:bom:2.29.12' } + jacocoTestReport { reports { xml.getRequired().set(true) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java index 2026a203b9..683794ab54 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java @@ -95,8 +95,8 @@ public void setUp() { when(metadata.indices()).thenReturn(Map.of(ML_AGENT_INDEX, agentindexMetadata, ML_MEMORY_META_INDEX, memorymetaindexMetadata)); when(agentindexMetadata.mapping()).thenReturn(agentmappingMetadata); when(memorymetaindexMetadata.mapping()).thenReturn(memorymappingMetadata); - when(agentmappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(2)))); - when(memorymappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(2)))); + when(agentmappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, 3))); + when(memorymappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, 2))); settings = Settings.builder().put("test_key", 10).build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); diff --git a/plugin/build.gradle b/plugin/build.gradle index e1f5e225f9..f8a3fdf844 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -50,6 +50,9 @@ dependencies { implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" + // Multi-tenant SDK Client + implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_version}" + implementation "org.opensearch:common-utils:${common_utils_version}" implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") implementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") @@ -336,6 +339,7 @@ jacocoTestCoverageVerification { check.dependsOn jacocoTestCoverageVerification configurations.all { + exclude group: "org.jetbrains", module: "annotations" resolutionStrategy.force 'org.apache.commons:commons-lang3:3.10' resolutionStrategy.force 'commons-logging:commons-logging:1.2' resolutionStrategy.force 'org.objenesis:objenesis:3.2' @@ -348,6 +352,15 @@ configurations.all { resolutionStrategy.force 'org.slf4j:slf4j-api:1.7.36' resolutionStrategy.force 'org.codehaus.plexus:plexus-utils:3.3.0' exclude group: "org.jetbrains", module: "annotations" + resolutionStrategy.force "org.opensearch.client:opensearch-rest-client:${opensearch_version}" + resolutionStrategy.force "org.apache.httpcomponents.core5:httpcore5:5.3.1" + resolutionStrategy.force "org.apache.httpcomponents.core5:httpcore5-h2:5.3.1" + resolutionStrategy.force "org.apache.httpcomponents.client5:httpclient5:5.4.1" + resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" + resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${versions.jackson_databind}" + resolutionStrategy.force "org.apache.logging.log4j:log4j-api:2.24.2" + resolutionStrategy.force "org.apache.logging.log4j:log4j-core:2.24.2" + resolutionStrategy.force "jakarta.json:jakarta.json-api:2.1.3" } apply plugin: 'com.netflix.nebula.ospackage' diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index 7b953c341a..12611a942d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -10,6 +10,8 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import java.util.Objects; + import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -65,6 +67,7 @@ public GetConnectorTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.fromActionRequest(request); String connectorId = mlConnectorGetRequest.getConnectorId(); + String tenantId = mlConnectorGetRequest.getTenantId(); FetchSourceContext fetchSourceContext = getFetchSourceContext(mlConnectorGetRequest.isReturnContent()); GetRequest getRequest = new GetRequest(ML_CONNECTOR_INDEX).id(connectorId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); @@ -77,6 +80,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener createComponents( mlIndicesHandler = new MLIndicesHandler(clusterService, client); encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + SdkClient sdkClient = SdkClientFactory + .createSdkClient( + client, + xContentRegistry, + // Here we assume remote metadata client is only used with tenant awareness. + // This may change in the future allowing more options for this map + ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings) + ? Map + .ofEntries( + Map.entry(REMOTE_METADATA_TYPE_KEY, REMOTE_METADATA_TYPE.get(settings)), + Map.entry(REMOTE_METADATA_ENDPOINT_KEY, REMOTE_METADATA_ENDPOINT.get(settings)), + Map.entry(REMOTE_METADATA_REGION_KEY, REMOTE_METADATA_REGION.get(settings)), + Map.entry(REMOTE_METADATA_SERVICE_NAME_KEY, REMOTE_METADATA_SERVICE_NAME.get(settings)), + Map.entry(TENANT_AWARE_KEY, "true"), + Map.entry(TENANT_ID_FIELD_KEY, TENANT_ID_FIELD) + ) + : Collections.emptyMap() + ); + mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); @@ -743,7 +776,7 @@ public List getRestHandlers( RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); - RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); + RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); @@ -976,7 +1009,12 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS, - MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE + MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE, + MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED, + MLCommonsSettings.REMOTE_METADATA_TYPE, + MLCommonsSettings.REMOTE_METADATA_ENDPOINT, + MLCommonsSettings.REMOTE_METADATA_REGION, + MLCommonsSettings.REMOTE_METADATA_SERVICE_NAME ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java index 0c1e124e4c..fe46959d75 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java @@ -9,14 +9,18 @@ import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import static org.opensearch.ml.utils.RestActionUtils.returnContent; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,10 +31,20 @@ public class RestMLGetConnectorAction extends BaseRestHandler { private static final String ML_GET_CONNECTOR_ACTION = "ml_get_connector_action"; + private ClusterService clusterService; + + private Settings settings; + + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + /** * Constructor */ - public RestMLGetConnectorAction() {} + public RestMLGetConnectorAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.clusterService = clusterService; + this.settings = settings; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -59,7 +73,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLConnectorGetRequest getRequest(RestRequest request) throws IOException { String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); boolean returnContent = returnContent(request); - - return new MLConnectorGetRequest(connectorId, returnContent); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + return new MLConnectorGetRequest(connectorId, tenantId, returnContent); } } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index e69ee87ca3..66b0f163f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -5,6 +5,11 @@ package org.opensearch.ml.settings; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY; + import java.util.List; import java.util.function.Function; @@ -263,4 +268,53 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_CONTROLLER_ENABLED = Setting .boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * Indicates whether multi-tenancy is enabled in ML Commons. + * + * This is a static setting that must be configured before starting OpenSearch. It can be set in the following ways, in priority order: + * + *
    + *
  1. As a command-line argument using the -E flag (this overrides other options): + *
    +     *       ./bin/opensearch -Eplugins.ml_commons.multi_tenancy_enabled=true
    +     *       
    + *
  2. + *
  3. As a system property using OPENSEARCH_JAVA_OPTS (this overrides opensearch.yml): + *
    +     *       export OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true"
    +     *       ./bin/opensearch
    +     *       
    + * Or inline when starting OpenSearch: + *
    +     *       OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true" ./bin/opensearch
    +     *       
    + *
  4. + *
  5. In the opensearch.yml configuration file: + *
    +     *       plugins.ml_commons.multi_tenancy_enabled: true
    +     *       
    + *
  6. + *
+ * + * After setting this option, a full cluster restart is required for the changes to take effect. + */ + public static final Setting ML_COMMONS_MULTI_TENANCY_ENABLED = Setting + .boolSetting("plugins.ml_commons.multi_tenancy_enabled", false, Setting.Property.NodeScope); + + /** This setting sets the remote metadata type */ + public static final Setting REMOTE_METADATA_TYPE = Setting + .simpleString("plugins.ml_commons." + REMOTE_METADATA_TYPE_KEY, Setting.Property.NodeScope, Setting.Property.Final); + + /** This setting sets the remote metadata endpoint */ + public static final Setting REMOTE_METADATA_ENDPOINT = Setting + .simpleString("plugins.flow_framework." + REMOTE_METADATA_ENDPOINT_KEY, Setting.Property.NodeScope, Setting.Property.Final); + + /** This setting sets the remote metadata region */ + public static final Setting REMOTE_METADATA_REGION = Setting + .simpleString("plugins.flow_framework." + REMOTE_METADATA_REGION_KEY, Setting.Property.NodeScope, Setting.Property.Final); + + /** This setting sets the remote metadata service name */ + public static final Setting REMOTE_METADATA_SERVICE_NAME = Setting + .simpleString("plugins.flow_framework." + REMOTE_METADATA_SERVICE_NAME_KEY, Setting.Property.NodeScope, Setting.Property.Final); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index 93159125de..e32e3af89f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -11,14 +11,20 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.settings.SettingsChangeListener; + +import com.google.common.annotations.VisibleForTesting; public class MLFeatureEnabledSetting { @@ -32,6 +38,11 @@ public class MLFeatureEnabledSetting { private volatile Boolean isBatchIngestionEnabled; private volatile Boolean isBatchInferenceEnabled; + // This is to identify if this node is in multi-tenancy or not. + private volatile Boolean isMultiTenancyEnabled; + + private final List listeners = new ArrayList<>(); + public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); @@ -40,6 +51,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings); isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings); isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings); + isMultiTenancyEnabled = ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings); clusterService .getClusterSettings() @@ -111,4 +123,23 @@ public Boolean isOfflineBatchIngestionEnabled() { public Boolean isOfflineBatchInferenceEnabled() { return isBatchInferenceEnabled; } + + /** + * Whether the multi-tenancy feature is enabled. If disabled, tenant id will be null. + * @return whether the multi tenancy feature is enabled. + */ + public boolean isMultiTenancyEnabled() { + return isMultiTenancyEnabled; + } + + public void addListener(SettingsChangeListener listener) { + listeners.add(listener); + } + + @VisibleForTesting + void notifyMultiTenancyListeners(boolean isEnabled) { + for (SettingsChangeListener listener : listeners) { + listener.onMultiTenancyEnabledChanged(isEnabled); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/TenantAwareHelper.java b/plugin/src/main/java/org/opensearch/ml/utils/TenantAwareHelper.java new file mode 100644 index 0000000000..12bdddb62f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/TenantAwareHelper.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils; + +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.input.Constants; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class TenantAwareHelper { + + /** + * Validates the tenant ID based on the multi-tenancy feature setting. + * + * @param mlFeatureEnabledSetting The settings that indicate whether the multi-tenancy feature is enabled. + * @param tenantId The tenant ID to validate. + * @param listener The action listener to handle failure cases. + * @return true if the tenant ID is valid or if multi-tenancy is not enabled; false if the tenant ID is invalid and multi-tenancy is enabled. + */ + public static boolean validateTenantId(MLFeatureEnabledSetting mlFeatureEnabledSetting, String tenantId, ActionListener listener) { + if (mlFeatureEnabledSetting.isMultiTenancyEnabled() && tenantId == null) { + listener.onFailure(new OpenSearchStatusException("You don't have permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else + return true; + } + + /** + * Validates the tenant resource by comparing the tenant ID from the request with the tenant ID from the resource. + * + * @param mlFeatureEnabledSetting The settings that indicate whether the multi-tenancy feature is enabled. + * @param tenantIdFromRequest The tenant ID obtained from the request. + * @param tenantIdFromResource The tenant ID obtained from the resource. + * @param listener The action listener to handle failure cases. + * @return true if the tenant IDs match or if multi-tenancy is not enabled; false if the tenant IDs do not match and multi-tenancy is enabled. + */ + public static boolean validateTenantResource( + MLFeatureEnabledSetting mlFeatureEnabledSetting, + String tenantIdFromRequest, + String tenantIdFromResource, + ActionListener listener + ) { + if (mlFeatureEnabledSetting.isMultiTenancyEnabled() && !Objects.equals(tenantIdFromRequest, tenantIdFromResource)) { + listener.onFailure(new OpenSearchStatusException("You don't have permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else + return true; + } + + public static boolean isTenantFilteringEnabled(SearchRequest searchRequest) { + SearchSourceBuilder searchSourceBuilder = searchRequest.source(); + if (searchSourceBuilder != null) { + QueryBuilder queryBuilder = searchSourceBuilder.query(); + if (queryBuilder instanceof TermQueryBuilder) { + TermQueryBuilder termQuery = (TermQueryBuilder) queryBuilder; + return TENANT_ID_FIELD.equals(termQuery.fieldName()); // Tenant filtering is enabled + } + } + return false; // Tenant filtering is not enabled + } + + public static String getTenantID(Boolean isMultiTenancyEnabled, RestRequest restRequest) { + if (!isMultiTenancyEnabled) { + return null; + } + + Map> headers = restRequest.getHeaders(); + if (headers == null) { + throw new OpenSearchStatusException("Rest request headers can't be null", RestStatus.FORBIDDEN); + } + + List tenantIdList = headers.get(Constants.TENANT_ID_HEADER); + if (tenantIdList == null || tenantIdList.isEmpty()) { + throw new OpenSearchStatusException("Tenant ID header is missing or has no value", RestStatus.FORBIDDEN); + } + + String tenantId = tenantIdList.get(0); + if (tenantId == null) { + throw new OpenSearchStatusException("Tenant ID can't be null", RestStatus.FORBIDDEN); + } + + return tenantId; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java index 7df3fff3d5..050cfa7104 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java @@ -12,9 +12,9 @@ import java.util.ArrayList; import java.util.List; -import org.apache.lucene.tests.util.LuceneTestCase; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.action.ActionFuture; @@ -69,10 +69,9 @@ public void setUp() throws Exception { irisIndexName = "iris_data_for_prediction_it"; loadIrisData(irisIndexName); - // TODO: open these lines when this bug fix merged https://github.com/oracle/tribuo/issues/223 - // modelId = trainKmeansWithIrisData(irisIndexName, false); - // MLModel kMeansModel = getModel(kMeansModelId); - // assertNotNull(kMeansModel); + kMeansModelId = trainKmeansWithIrisData(irisIndexName, false); + MLModel kMeansModel = getModel(kMeansModelId); + assertNotNull(kMeansModel); batchRcfModelId = trainBatchRCFWithDataFrame(500, false); fitRcfModelId = trainFitRCFWithDataFrame(500, false); @@ -82,18 +81,19 @@ public void setUp() throws Exception { assertNotNull(batchRcfModel); } - @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + @Test public void testPredictionWithSearchInput_KMeans() { MLInputDataset inputDataset = new SearchQueryInputDataset(ImmutableList.of(irisIndexName), irisDataQuery()); predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } - @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + @Test public void testPredictionWithDataInput_KMeans() { MLInputDataset inputDataset = new DataFrameInputDataset(irisDataFrame()); predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } + @Test public void testPredictionWithoutDataset_KMeans() { exceptionRule.expect(ActionRequestValidationException.class); exceptionRule.expectMessage("input data can't be null"); @@ -103,7 +103,7 @@ public void testPredictionWithoutDataset_KMeans() { predictionFuture.actionGet(); } - @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + @Test public void testPredictionWithEmptyDataset_KMeans() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("No document found"); @@ -114,6 +114,7 @@ public void testPredictionWithEmptyDataset_KMeans() { predictionFuture.actionGet(); } + @Test public void testPredictionWithSearchInput_LogisticRegression() { MLInputDataset inputDataset = new SearchQueryInputDataset( ImmutableList.of(irisIndexName), @@ -122,11 +123,13 @@ public void testPredictionWithSearchInput_LogisticRegression() { predictAndVerify(logisticRegressionModelId, inputDataset, FunctionName.LOGISTIC_REGRESSION, null, IRIS_DATA_SIZE); } + @Test public void testPredictionWithDataFrame_BatchRCF() { MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize)); predictAndVerify(batchRcfModelId, inputDataset, FunctionName.BATCH_RCF, null, batchRcfDataSize); } + @Test public void testPredictionWithDataFrame_FitRCF() { MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize, true)); DataFrame dataFrame = predictAndVerify( @@ -138,6 +141,7 @@ public void testPredictionWithDataFrame_FitRCF() { ); } + @Test public void testPredictionWithDataFrame_LinearRegression() { int size = 1; int feet = 20; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java index 6934f09c96..d195371a55 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java @@ -11,18 +11,24 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -30,6 +36,7 @@ import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -42,17 +49,30 @@ public class RestMLGetConnectorActionTests extends OpenSearchTestCase { @Rule public ExpectedException thrown = ExpectedException.none(); + @Mock + private ClusterService clusterService; + private RestMLGetConnectorAction restMLGetConnectorAction; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; + Settings settings; + @Mock RestChannel channel; @Before public void setup() { - restMLGetConnectorAction = new RestMLGetConnectorAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -72,8 +92,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetConnectorAction mlGetConnectorAction = new RestMLGetConnectorAction(); - assertNotNull(mlGetConnectorAction); + assertNotNull(restMLGetConnectorAction); } public void testGetName() { @@ -104,7 +123,6 @@ public void test_PrepareRequest() throws Exception { private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_CONNECTOR_ID, "connector_id"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); - return request; + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java b/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java new file mode 100644 index 0000000000..5608b7ef81 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.settings; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.settings.SettingsChangeListener; + +public class MLFeatureEnabledSettingTests { + @Mock + private ClusterService clusterService; + private Settings settings; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private SettingsChangeListener listener; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()) + .thenReturn( + new ClusterSettings( + settings, + Set + .of( + ML_COMMONS_MULTI_TENANCY_ENABLED, + ML_COMMONS_REMOTE_INFERENCE_ENABLED, + ML_COMMONS_AGENT_FRAMEWORK_ENABLED, + ML_COMMONS_LOCAL_MODEL_ENABLED, + ML_COMMONS_CONTROLLER_ENABLED, + ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, + ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, + ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED + ) + ) + ); + + mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings); + listener = mock(SettingsChangeListener.class); + } + + @Test + public void testAddListenerAndNotify() { + mlFeatureEnabledSetting.addListener(listener); + + // Simulate settings change + mlFeatureEnabledSetting.notifyMultiTenancyListeners(false); + + // Verify listener is notified + verify(listener, times(1)).onMultiTenancyEnabledChanged(false); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TenantAwareHelperTests.java b/plugin/src/test/java/org/opensearch/ml/utils/TenantAwareHelperTests.java new file mode 100644 index 0000000000..c5adfce824 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/utils/TenantAwareHelperTests.java @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; +import static org.opensearch.ml.utils.TestHelper.xContentRegistry; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.input.Constants; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.rest.FakeRestRequest; + +public class TenantAwareHelperTests { + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private ActionListener actionListener; + + Settings settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), true).build(); + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void testValidateTenantId_MultiTenancyEnabled_TenantIdNull() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + boolean result = TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, null, actionListener); + assertFalse(result); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assert exception.status() == RestStatus.FORBIDDEN; + assert exception.getMessage().equals("You don't have permission to access this resource"); + } + + @Test + public void testValidateTenantId_MultiTenancyEnabled_TenantIdPresent() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + boolean result = TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, "_tenant_id", actionListener); + assertTrue(result); + } + + @Test + public void testValidateTenantId_MultiTenancyDisabled() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + + boolean result = TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, null, actionListener); + + assertTrue(result); + } + + @Test + public void testValidateTenantResource_MultiTenancyEnabled_TenantIdMismatch() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + boolean result = TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, null, "different_tenant_id", actionListener); + assertFalse(result); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assert exception.status() == RestStatus.FORBIDDEN; + assert exception.getMessage().equals("You don't have permission to access this resource"); + } + + @Test + public void testValidateTenantResource_MultiTenancyEnabled_TenantIdMatch() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + boolean result = TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, "_tenant_id", "_tenant_id", actionListener); + assertTrue(result); + } + + @Test + public void testValidateTenantResource_MultiTenancyDisabled() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + boolean result = TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, "_tenant_id", "different_tenant_id", actionListener); + assertTrue(result); + } + + @Test + public void testIsTenantFilteringEnabled_TenantFilteringEnabled() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(new TermQueryBuilder(TENANT_ID_FIELD, "123456")); + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + + boolean result = TenantAwareHelper.isTenantFilteringEnabled(searchRequest); + assertTrue(result); + } + + @Test + public void testIsTenantFilteringEnabled_TenantFilteringDisabled() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchAllQuery()); + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + + boolean result = TenantAwareHelper.isTenantFilteringEnabled(searchRequest); + assertFalse(result); + } + + @Test + public void testIsTenantFilteringEnabled_NoQuery() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + + boolean result = TenantAwareHelper.isTenantFilteringEnabled(searchRequest); + assertFalse(result); + } + + @Test + public void testIsTenantFilteringEnabled_NullSource() { + SearchRequest searchRequest = new SearchRequest(); + + boolean result = TenantAwareHelper.isTenantFilteringEnabled(searchRequest); + assertFalse(result); + } + + @Test + public void testGetTenantID_IndependentNode() { + String tenantId = "test-tenant"; + Map> headers = new HashMap<>(); + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build(); + + String actualTenantID = TenantAwareHelper.getTenantID(Boolean.TRUE, restRequest); + Assert.assertEquals(tenantId, actualTenantID); + } + + @Test + public void testGetTenantID_IndependentNode_NullTenantID() { + Map> headers = new HashMap<>(); + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(null)); + RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build(); + + try { + TenantAwareHelper.getTenantID(Boolean.TRUE, restRequest); + Assert.fail("Expected OpenSearchStatusException"); + } catch (Exception e) { + Assert.assertTrue(e instanceof OpenSearchStatusException); + Assert.assertEquals(RestStatus.FORBIDDEN, ((OpenSearchStatusException) e).status()); + Assert.assertEquals("Tenant ID can't be null", e.getMessage()); + } + } + + @Test + public void testGetTenantID_NotIndependentNode() { + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + String tenantId = "test-tenant"; + Map> headers = new HashMap<>(); + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build(); + + String tenantID = TenantAwareHelper.getTenantID(Boolean.FALSE, restRequest); + Assert.assertNull(tenantID); + } +}