Skip to content

Commit

Permalink
Adding extra check to ensure model gets undeployed and then gets deleted
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Jan 12, 2024
1 parent 494d972 commit f78109c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch;

import org.opensearch.ml.common.model.MLModelState;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
Expand Down Expand Up @@ -61,6 +62,8 @@
import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.TestUtils.MAX_RETRY;
import static org.opensearch.neuralsearch.TestUtils.MAX_TIME_OUT_INTERVAL;

import lombok.AllArgsConstructor;
import lombok.Getter;
Expand Down Expand Up @@ -664,8 +667,10 @@ protected void deleteModel(String modelId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

// after model undeploy returns, the max interval to update model status is 3s in ml-commons CronJob.
Thread.sleep(3000);
// wait for model undeploy to complete.
// Sometimes the undeploy action results in a DEPLOY_FAILED state. But this does not block the model from being deleted.
// So set both UNDEPLOYED and DEPLOY_FAILED as exit state.
pollForModelState(modelId, Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED));

makeRequest(
client(),
Expand All @@ -677,6 +682,46 @@ protected void deleteModel(String modelId) {
);
}

protected void pollForModelState(String modelId, Set<MLModelState> exitModelStates) throws InterruptedException {
MLModelState currentState = null;
for (int i = 0; i < MAX_RETRY; i++) {
Thread.sleep(MAX_TIME_OUT_INTERVAL);
currentState = getModelState(modelId);
if (exitModelStates.contains(currentState)) {
return;
}
}
fail(
String.format(
LOCALE,
"Model state does not reached exit states %s after %d attempts with interval of %d ms, latest model state: %s.",
StringUtils.join(exitModelStates, ","),
MAX_RETRY,
MAX_TIME_OUT_INTERVAL,
currentState
)
);
}

@SneakyThrows
protected MLModelState getModelState(String modelId) {
Response getModelResponse = makeRequest(
client(),
"GET",
String.format(LOCALE, "/_plugins/_ml/models/%s", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> getModelResponseJson = XContentHelper.convertToMap(
XContentType.JSON.xContent(),
EntityUtils.toString(getModelResponse.getEntity()),
false
);
String modelState = (String) getModelResponseJson.get("model_state");
return MLModelState.valueOf(modelState);
}

public boolean isUpdateClusterSettings() {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public class TestUtils {
public static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
public static final String SPARSE_ENCODING_PROCESSOR = "sparse_encoding";
public static final int MAX_TIME_OUT_INTERVAL = 3000;
public static final int MAX_RETRY = 3;

/**
* Convert an xContentBuilder to a map
Expand Down

0 comments on commit f78109c

Please sign in to comment.