diff --git a/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java b/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java index ac38ce24c6..509d9f9dc2 100644 --- a/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java +++ b/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java @@ -12,7 +12,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; @@ -31,7 +30,6 @@ import org.opensearch.core.common.Strings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; @@ -186,6 +184,9 @@ private void triggerAutoDeployModels(List addedNodes) { modelAutoRedeployArrangements.add(modelAutoRedeployArrangement); }); redeployAModel(); + } else { + log.info("Could not find any models in the index, not performing auto reloading!"); + startCronjobAndClearListener(); } }, e -> { if (e instanceof IndexNotFoundException) { @@ -241,9 +242,7 @@ private void queryRunningModels(ActionListener listener) { String[] includes = new String[] { MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, MLModel.PLANNING_WORKER_NODES_FIELD, - MLModel.DEPLOY_TO_ALL_NODES_FIELD, - MLModel.FUNCTION_NAME_FIELD, - MLModel.ALGORITHM_FIELD }; + MLModel.DEPLOY_TO_ALL_NODES_FIELD }; String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes); @@ -261,29 +260,22 @@ private void queryRunningModels(ActionListener listener) { private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeployArrangement) { if (modelAutoRedeployArrangement == null) { log.info("No more models in arrangement, skipping the redeployment"); + startCronjobAndClearListener(); return; } String modelId = modelAutoRedeployArrangement.getSearchResponse().getId(); List addedNodes = modelAutoRedeployArrangement.getAddedNodes(); - Map sourceAsMap = modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap(); - String functionName = (String) Optional - .ofNullable(sourceAsMap.get(MLModel.FUNCTION_NAME_FIELD)) - .orElse(sourceAsMap.get(MLModel.ALGORITHM_FIELD)); - if (functionName == null) { - log - .error( - "Model function_name or algorithm is null, model is not in correct status, please check the model, model id is: {}", - modelId - ); - return; - } - if (FunctionName.REMOTE == FunctionName.from(functionName)) { - log.info("Skipping redeploying remote model {} as remote model deployment can be done at prediction time.", modelId); - return; - } - List planningWorkerNodes = (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD); - Integer autoRedeployRetryTimes = (Integer) sourceAsMap.get(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD); - Boolean deployToAllNodes = (Boolean) Optional.ofNullable(sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)).orElse(false); + List planningWorkerNodes = (List) modelAutoRedeployArrangement + .getSearchResponse() + .getSourceAsMap() + .get(MLModel.PLANNING_WORKER_NODES_FIELD); + Integer autoRedeployRetryTimes = (Integer) modelAutoRedeployArrangement + .getSearchResponse() + .getSourceAsMap() + .get(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD); + Boolean deployToAllNodes = (Boolean) Optional + .ofNullable(modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap().get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)) + .orElse(false); // calculate node ids. String[] nodeIds = null; if (deployToAllNodes || !allowCustomDeploymentPlan) { @@ -302,6 +294,7 @@ private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeploy .info( "Allow custom deployment plan is true and deploy to all nodes is false and added nodes are not in planning worker nodes list, not to auto redeploy the model to the new nodes!" ); + redeployAModel(); return; } diff --git a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java index a0acf8831f..6f8d4e453d 100644 --- a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java @@ -609,34 +609,6 @@ public void test_redeployAModel_with_needRedeployArray_isEmpty() { mlModelAutoReDeployer.redeployAModel(); } - public void test_buildAutoReloadArrangement_skippingRemoteModel_success() throws Exception { - Settings settings = Settings - .builder() - .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true) - .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES.getKey(), 3) - .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true) - .put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), false) - .build(); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.localNode()).thenReturn(localNode); - when(clusterService.getClusterSettings()).thenReturn(getClusterSettings(settings)); - mockClusterDataNodes(clusterService); - - mlModelAutoReDeployer = spy( - new MLModelAutoReDeployer(clusterService, client, settings, mlModelManager, searchRequestBuilderFactory) - ); - - SearchResponse searchResponse = buildDeployToAllNodesTrueSearchResponse("RemoteModelResult.json"); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(searchResponse); - return null; - }).when(searchRequestBuilder).execute(isA(ActionListener.class)); - mlModelAutoReDeployer.buildAutoReloadArrangement(addedNodes, clusterManagerNodeId); - verify(client, never()).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), any(ActionListener.class)); - } - private SearchResponse buildDeployToAllNodesTrueSearchResponse(String file) throws Exception { MLModel mlModel = buildModelWithJsonFile(file); return createResponseWithModel(mlModel); diff --git a/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json b/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json deleted file mode 100644 index fe7103fdee..0000000000 --- a/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "last_deployed_time": 1722954415807, - "model_version": "619", - "created_time": 1722954415642, - "deploy_to_all_nodes": true, - "is_hidden": false, - "description": "This is a test model", - "model_state": "DEPLOYED", - "planning_worker_node_count": 1, - "auto_redeploy_retry_times": 0, - "last_updated_time": 1723691017054, - "name": "my sagemaker model", - "connector_id": "z3kVKJEBAfFjoGUT_Ui7", - "current_worker_node_count": 0, - "model_group_id": "MiJPJ5EBQM-QzppeWrTJ", - "planning_worker_nodes": [ - "DecGG5pDQYaqelLMLcIV9Q" - ], - "algorithm": "REMOTE" -} diff --git a/plugin/src/test/resources/org/opensearch/ml/autoredeploy/TracedSmallModelRequest.json b/plugin/src/test/resources/org/opensearch/ml/autoredeploy/TracedSmallModelRequest.json index 9fc53f3b91..0173665a5d 100644 --- a/plugin/src/test/resources/org/opensearch/ml/autoredeploy/TracedSmallModelRequest.json +++ b/plugin/src/test/resources/org/opensearch/ml/autoredeploy/TracedSmallModelRequest.json @@ -11,4 +11,4 @@ "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}" }, "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true" -} +} \ No newline at end of file