From 2a4bacd2923268c32285975b3d4168c6367b7458 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Tue, 7 May 2024 13:18:40 +0800 Subject: [PATCH] add wait for model to be undeployed Signed-off-by: Hailong Cui --- .../ml/tools/ToolIntegrationWithLLMTest.java | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index 02b3a2450a..50dd17f589 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -8,15 +8,19 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.rest.RestBaseAgentToolsIT; import org.opensearch.ml.utils.TestHelper; @@ -26,6 +30,10 @@ @Log4j2 public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT { + + private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + protected HttpServer server; protected String modelId; protected String agentId; @@ -66,9 +74,42 @@ public void stopMockLLM() { public void deleteModel() throws IOException { log.info("deleteModel"); undeployModel(modelId); + waitModelUndeployed(modelId); deleteModel(client(), modelId, null); } + @SneakyThrows + private void waitModelUndeployed(String modelId) { + Predicate> condition = responseInMap -> { + String state = responseInMap.get(MLModel.MODEL_STATE_FIELD).toString(); + return !state.equals(MLModelState.DEPLOYED.toString()) + && !state.equals(MLModelState.DEPLOYING.toString()) + && !state.equals(MLModelState.PARTIALLY_DEPLOYED.toString()); + }; + waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition); + } + + @SneakyThrows + protected Map waitResponseMeetingCondition( + String method, + String endpoint, + String jsonEntity, + Predicate> condition + ) { + for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) { + Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Map responseInMap = parseResponseToMap(response); + if (condition.test(responseInMap)) { + return responseInMap; + } + logger.info("The {}-th response: {}", i, responseInMap.toString()); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds."); + return null; + } + private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException { int retryTimes = 0; String connectorId = null;