Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

system error handling #1893

Merged
merged 7 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
Expand All @@ -17,17 +19,23 @@
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.error.ErrorMessageFactory;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

Expand Down Expand Up @@ -62,7 +70,28 @@
@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request);
return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new RestToXContentListener<>(channel));

return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new ActionListener<>() {
@Override
public void onResponse(MLExecuteTaskResponse response) {
try {
sendResponse(channel, response);

Check warning on line 78 in plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

View check run for this annotation

Codecov / codecov/patch

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java#L78

Added line #L78 was not covered by tests
} catch (Exception e) {
reportError(channel, e, INTERNAL_SERVER_ERROR);
}

Check warning on line 81 in plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

View check run for this annotation

Codecov / codecov/patch

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java#L81

Added line #L81 was not covered by tests
}

@Override
public void onFailure(Exception e) {
RestStatus status;
if (isClientError(e)) {
status = BAD_REQUEST;
} else {
status = INTERNAL_SERVER_ERROR;
}
reportError(channel, e, status);
}
});
}

/**
Expand Down Expand Up @@ -95,4 +124,16 @@

return new MLExecuteTaskRequest(functionName, input);
}

private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) throws Exception {
channel.sendResponse(new RestToXContentListener<MLExecuteTaskResponse>(channel).buildResponse(response));
}

Check warning on line 130 in plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

View check run for this annotation

Codecov / codecov/patch

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java#L129-L130

Added lines #L129 - L130 were not covered by tests

private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, ErrorMessageFactory.createErrorMessage(e, status.getStatus()).toString()));
}

private boolean isClientError(Exception e) {
return e instanceof IllegalArgumentException || e instanceof IllegalAccessException;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.rest.RestStatus;

import com.fasterxml.jackson.databind.ObjectMapper;

import lombok.Getter;
import lombok.SneakyThrows;

/** Error Message. */
public class ErrorMessage {

protected Throwable exception;

private final int status;

@Getter
private final String type;

@Getter
private final String reason;

@Getter
private final String details;

/** Error Message Constructor. */
public ErrorMessage(Throwable exception, int status) {
this.exception = exception;
this.status = status;

this.type = fetchType();
this.reason = fetchReason();
this.details = fetchDetails();
}

private String fetchType() {
return exception.getClass().getSimpleName();
}

protected String fetchReason() {
return status == RestStatus.BAD_REQUEST.getStatus() ? "Invalid Request" : "System Error";
}

protected String fetchDetails() {
// Some exception prints internal information (full class name) which is security concern
return emptyStringIfNull(exception.getLocalizedMessage());
}

private String emptyStringIfNull(String str) {
return str != null ? str : "";
}

@SneakyThrows

Check warning on line 61 in plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessage.java

View check run for this annotation

Codecov / codecov/patch

plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessage.java#L61

Added line #L61 was not covered by tests
@Override
public String toString() {
ObjectMapper objectMapper = new ObjectMapper();
Map<String, Object> errorContent = new HashMap<>();
errorContent.put("type", type);
errorContent.put("reason", reason);
errorContent.put("details", details);
Map<String, Object> errMessage = new HashMap<>();
errMessage.put("status", status);
errMessage.put("error", errorContent);

return objectMapper.writeValueAsString(errMessage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import org.opensearch.OpenSearchException;

import lombok.experimental.UtilityClass;

@UtilityClass
public class ErrorMessageFactory {
/**
* Create error message based on the exception type.
*
* @param e exception to create error message
* @param status exception status code
* @return error message
*/
public static ErrorMessage createErrorMessage(Throwable e, int status) {
Throwable t = e;
int st = status;
if (t instanceof OpenSearchException) {
st = ((OpenSearchException) t).status().getStatus();
} else {
t = unwrapCause(e);
}

return new ErrorMessage(t, st);
}

protected static Throwable unwrapCause(Throwable t) {
Throwable result = t;
if (result instanceof OpenSearchException) {
return result;
}
if (result.getCause() == null) {
return result;
}
result = unwrapCause(result.getCause());
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -33,9 +36,11 @@
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -132,6 +137,69 @@ public void testPrepareRequest() throws Exception {
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest1() throws Exception {
doNothing().when(channel).sendResponse(isA(RestResponse.class));
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest2() throws Exception {
doThrow(new IllegalArgumentException("input error")).when(channel).sendResponse(isA(RestResponse.class));
doNothing().when(channel).sendResponse(isA(BytesRestResponse.class));
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequestClientError() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException("input error"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequestSystemError() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("system error"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest_disabled() {
RestRequest request = getExecuteAgentRestRequest();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.opensearch.OpenSearchException;
import org.opensearch.core.rest.RestStatus;

public class ErrorMessageFactoryTests {

private Throwable nonOpenSearchThrowable = new Throwable();
private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable);

@Test
public void openSearchExceptionShouldCreateEsErrorMessage() {
Exception exception = new OpenSearchException(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() {
Exception exception = new Exception(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertFalse(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = (Exception) openSearchThrowable;
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable)));
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}
}
Loading
Loading