Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Backport 2.x) Fixes Two Flaky IT classes RestMLGuardrailsIT & ToolIntegrationWithLLMTest #3263

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
}
return taskDone.get();
}, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
assertTrue(taskDone.get());
assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get());
}

public String registerConnector(String createConnectorInput) throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
Expand All @@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");

if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
Expand All @@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept
exceptionRule.expect(ResponseException.class);
exceptionRule.expectMessage("guardrails triggered for user input");
Response response = createConnector(completionModelConnectorEntity);

Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");

waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
predictRemoteModel(modelId, predictInput);
Expand All @@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
predictRemoteModel(modelId, predictInput);
}
Expand All @@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
Map responseMap = parseResponseToMap(response);
String guardrailConnectorId = (String) responseMap.get("connector_id");

response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String guardrailModelId = (String) responseMap.get("model_id");

response = deployRemoteModel(guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Check the response from guardrails model that should be "accept".
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
response = predictRemoteModel(guardrailModelId, predictInput);
Expand All @@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
responseMap = (Map) responseMap.get("dataAsMap");
String validationResult = (String) responseMap.get("response");
Assert.assertTrue(validateRegex(validationResult, acceptRegex));

// Create predict model.
response = createConnector(completionModelConnectorEntity);
responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Predict.
predictInput = "{\n"
+ " \"parameters\": {\n"
Expand Down Expand Up @@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
Map responseMap = parseResponseToMap(response);
String guardrailConnectorId = (String) responseMap.get("connector_id");

// Create the model ID
response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String guardrailModelId = (String) responseMap.get("model_id");

response = deployRemoteModel(guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Check the response from guardrails model that should be "reject".
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}";
response = predictRemoteModel(guardrailModelId, predictInput);
Expand All @@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
responseMap = (Map) responseMap.get("dataAsMap");
String validationResult = (String) responseMap.get("response");
Assert.assertTrue(validateRegex(validationResult, rejectRegex));

// Create predict model.
response = createConnector(completionModelConnectorEntity);
responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,31 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
import org.opensearch.ml.utils.TestHelper;

import com.sun.net.httpserver.HttpServer;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT {

private int MAX_RETRIES;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

protected HttpServer server;
protected String modelId;
protected String agentId;
Expand All @@ -47,6 +55,7 @@ public void setupTestChatAgent() throws IOException, InterruptedException {
TimeUnit.SECONDS.sleep(1);
setupConversationalAgent(modelId);
log.info("model_id: {}, agent_id: {}", modelId, agentId);
MAX_RETRIES = this.getClusterHosts().size();
}

@After
Expand All @@ -62,9 +71,48 @@ public void stopMockLLM() {
@After
// Teardown: undeploys the model under test, blocks until the cluster reports it
// UNDEPLOYED (avoids flaky deletes racing an in-flight undeploy), then deletes it.
public void deleteModel() throws IOException {
    undeployModel(modelId);
    // Wait for the undeploy to be visible on the model document before deleting;
    // deleting a still-deployed model is rejected and made this test flaky.
    checkForModelUndeployedStatus(modelId);
    deleteModel(client(), modelId, null);
}

@SneakyThrows
// Polls GET /_plugins/_ml/models/{modelId} until the model document reports the
// UNDEPLOYED state, delegating retry/backoff to waitResponseMeetingCondition.
private void checkForModelUndeployedStatus(String modelId) {
    waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, this::isModelUndeployed);
}

// Returns true when the get-model response carries MODEL_STATE_FIELD == UNDEPLOYED;
// any parse/lookup failure is treated as "not yet undeployed" so polling continues.
private boolean isModelUndeployed(Response response) {
    try {
        Object rawState = parseResponseToMap(response).get(MLModel.MODEL_STATE_FIELD);
        return MLModelState.UNDEPLOYED.equals(MLModelState.from(rawState.toString()));
    } catch (Exception e) {
        return false;
    }
}

@SneakyThrows
/**
 * Repeatedly performs the given REST request until the response satisfies
 * {@code condition}, retrying up to MAX_RETRIES times with a fixed pause
 * between attempts.
 *
 * Each response must have HTTP status 200 (asserted), otherwise the test fails
 * immediately. If no attempt satisfies the condition, the test fails with a
 * descriptive message.
 *
 * @param method     HTTP method, e.g. "GET"
 * @param endpoint   request path
 * @param jsonEntity request body, or null for none
 * @param condition  predicate deciding whether the response is acceptable
 * @return the first response that meets the condition (never null in practice;
 *         the trailing return only satisfies the compiler after fail() throws)
 */
protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) {
    for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) {
        Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null);
        assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
        if (condition.test(response)) {
            return response;
        }
        logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString());
        // Only sleep between attempts: sleeping after the final attempt would
        // delay the inevitable fail() below by one full interval for no benefit.
        if (attempt < MAX_RETRIES) {
            Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
        }
    }
    fail(
        String
            .format(
                Locale.ROOT,
                "The response failed to meet condition after %d attempts. Attempted to perform %s : %s",
                MAX_RETRIES,
                method,
                endpoint
            )
    );
    return null; // unreachable: fail() always throws
}

private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException {
int retryTimes = 0;
String connectorId = null;
Expand Down
Loading