Skip to content

Commit

Permalink
adding check to transport layer
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Mar 19, 2024
1 parent 317fab6 commit 363429b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.RestActionUtils;
Expand Down Expand Up @@ -58,6 +60,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action

private volatile boolean enableAutomaticDeployment;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportPredictionTaskAction(
TransportService transportService,
Expand All @@ -69,6 +73,7 @@ public TransportPredictionTaskAction(
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
Expand All @@ -80,6 +85,7 @@ public TransportPredictionTaskAction(
this.xContentRegistry = xContentRegistry;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
Expand Down Expand Up @@ -107,6 +113,9 @@ public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
Expand Down Expand Up @@ -56,6 +57,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -94,6 +96,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportRegisterModelAction(
Expand All @@ -112,7 +115,8 @@ public TransportRegisterModelAction(
MLStats mlStats,
ModelAccessControlHelper modelAccessControlHelper,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelGroupManager mlModelGroupManager
MLModelGroupManager mlModelGroupManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLRegisterModelAction.NAME, transportService, actionFilters, MLRegisterModelRequest::new);
this.transportService = transportService;
Expand All @@ -129,6 +133,7 @@ public TransportRegisterModelAction(
this.modelAccessControlHelper = modelAccessControlHelper;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelGroupManager = mlModelGroupManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
this.settings = settings;

trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
Expand All @@ -147,6 +152,9 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -48,6 +49,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -83,6 +85,8 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase {

@Mock
private ModelAccessControlHelper modelAccessControlHelper;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Mock
ActionFilters actionFilters;
Expand Down Expand Up @@ -129,6 +133,7 @@ public void setup() {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);

transportPredictionTaskAction = spy(
new TransportPredictionTaskAction(
Expand All @@ -141,6 +146,7 @@ public void setup() {
xContentRegistry,
mlModelManager,
modelAccessControlHelper,
mlFeatureEnabledSetting,
settings
)
);
Expand Down Expand Up @@ -168,6 +174,21 @@ public void testPrediction_default_exception() {
assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage());
}

public void testPrediction_local_model_not_exception() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);

IllegalStateException e = assertThrows(
IllegalStateException.class,
() -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
);
assertEquals(
e.getMessage(),
LOCAL_MODEL_DISABLED_ERR_MSG
);
}

public void testPrediction_OpenSearchStatusException() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
Expand Down Expand Up @@ -60,6 +61,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
Expand Down Expand Up @@ -150,6 +152,9 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase {
@Mock
private ConnectorAccessControlHelper connectorAccessControlHelper;

@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
Expand Down Expand Up @@ -184,7 +189,8 @@ public void setup() throws IOException {
mlStats,
modelAccessControlHelper,
connectorAccessControlHelper,
mlModelGroupManager
mlModelGroupManager,
mlFeatureEnabledSetting
);
assertNotNull(transportRegisterModelAction);

Expand Down Expand Up @@ -216,6 +222,8 @@ public void setup() throws IOException {
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);

when(clusterService.localNode()).thenReturn(node2);
when(node2.getId()).thenReturn("node2Id");

Expand All @@ -226,6 +234,42 @@ public void setup() throws IOException {
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDoExecute_LocalModelDisabledException() {
when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
.functionName(FunctionName.TEXT_EMBEDDING)
.deployModel(true)
.modelGroupId("modelGroupID")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig(
"CUSTOM",
123,
TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.url("http://test_url")
.build();

MLRegisterModelRequest mlRegisterModelRequest = new MLRegisterModelRequest(registerModelInput);

IllegalStateException e = assertThrows(
IllegalStateException.class,
() -> transportRegisterModelAction.doExecute(task, mlRegisterModelRequest, actionListener)
);
assertEquals(
e.getMessage(),
LOCAL_MODEL_DISABLED_ERR_MSG
);
}

public void testDoExecute_userHasNoAccessException() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
Expand Down Expand Up @@ -328,7 +372,8 @@ public void testRegisterModelUrlNotAllowed() throws Exception {
mlStats,
modelAccessControlHelper,
connectorAccessControlHelper,
mlModelGroupManager
mlModelGroupManager,
mlFeatureEnabledSetting
);

IllegalArgumentException e = assertThrows(
Expand Down

0 comments on commit 363429b

Please sign in to comment.