From ff7fc7b6a23cb58a611e8452184863b638def361 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Tue, 7 May 2024 13:18:40 +0800 Subject: [PATCH] add wait for model to be undeployed Signed-off-by: Hailong Cui --- .../ml/tools/ToolIntegrationWithLLMTest.java | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index 02b3a2450a..f599c6fe05 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -8,15 +8,19 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.rest.RestBaseAgentToolsIT; import org.opensearch.ml.utils.TestHelper; @@ -26,6 +30,10 @@ @Log4j2 public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT { + + private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + protected HttpServer server; protected String modelId; protected String agentId; @@ -66,9 +74,46 @@ public void stopMockLLM() { public void deleteModel() throws IOException { log.info("deleteModel"); undeployModel(modelId); + waitModelUndeployed(modelId); deleteModel(client(), modelId, null); } + @SneakyThrows + private void waitModelUndeployed(String modelId) { + Predicate condition = response -> { + try { + Map responseInMap = parseResponseToMap(response); + String state = responseInMap.get(MLModel.MODEL_STATE_FIELD).toString(); + return !state.equals(MLModelState.DEPLOYED.toString()) + && !state.equals(MLModelState.DEPLOYING.toString()) + && !state.equals(MLModelState.PARTIALLY_DEPLOYED.toString()); + } catch (IOException e) { + return false; + } + }; + waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition); + } + + @SneakyThrows + protected Response waitResponseMeetingCondition( + String method, + String endpoint, + String jsonEntity, + Predicate condition + ) { + for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) { + Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + if (condition.test(response)) { + return response; + } + logger.info("The {}-th response: {}", i, response.toString()); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds."); + return null; + } + private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException { int retryTimes = 0; String connectorId = null;