diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index d8136eeb19..ae31269800 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -5,6 +5,8 @@ package org.opensearch.ml.rest; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; @@ -17,6 +19,8 @@ import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; @@ -24,10 +28,14 @@ import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.error.ErrorMessageFactory; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -62,7 +70,28 @@ public List routes() { @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request); - return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new RestToXContentListener<>(channel)); + + return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new ActionListener<>() { + @Override + public void onResponse(MLExecuteTaskResponse response) { + try { + sendResponse(channel, response); + } catch (Exception e) { + reportError(channel, e, INTERNAL_SERVER_ERROR); + } + } + + @Override + public void onFailure(Exception e) { + RestStatus status; + if (isClientError(e)) { + status = BAD_REQUEST; + } else { + status = INTERNAL_SERVER_ERROR; + } + reportError(channel, e, status); + } + }); } /** @@ -95,4 +124,16 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException { return new MLExecuteTaskRequest(functionName, input); } + + private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) throws Exception { + channel.sendResponse(new RestToXContentListener(channel).buildResponse(response)); + } + + private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { + channel.sendResponse(new BytesRestResponse(status, ErrorMessageFactory.createErrorMessage(e, status.getStatus()).toString())); + } + + private boolean isClientError(Exception e) { + return e instanceof IllegalArgumentException || e instanceof IllegalAccessException; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessage.java b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessage.java new file mode 100644 index 0000000000..8184b0a835 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessage.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils.error; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.rest.RestStatus; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import lombok.Getter; +import lombok.SneakyThrows; + +/** Error Message. */ +public class ErrorMessage { + + protected Throwable exception; + + private final int status; + + @Getter + private final String type; + + @Getter + private final String reason; + + @Getter + private final String details; + + /** Error Message Constructor. */ + public ErrorMessage(Throwable exception, int status) { + this.exception = exception; + this.status = status; + + this.type = fetchType(); + this.reason = fetchReason(); + this.details = fetchDetails(); + } + + private String fetchType() { + return exception.getClass().getSimpleName(); + } + + protected String fetchReason() { + return status == RestStatus.BAD_REQUEST.getStatus() ? "Invalid Request" : "System Error"; + } + + protected String fetchDetails() { + // Some exception prints internal information (full class name) which is security concern + return emptyStringIfNull(exception.getLocalizedMessage()); + } + + private String emptyStringIfNull(String str) { + return str != null ? str : ""; + } + + @SneakyThrows + @Override + public String toString() { + ObjectMapper objectMapper = new ObjectMapper(); + Map errorContent = new HashMap<>(); + errorContent.put("type", type); + errorContent.put("reason", reason); + errorContent.put("details", details); + Map errMessage = new HashMap<>(); + errMessage.put("status", status); + errMessage.put("error", errorContent); + + return objectMapper.writeValueAsString(errMessage); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java new file mode 100644 index 0000000000..30aace4be3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils.error; + +import org.opensearch.OpenSearchException; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public class ErrorMessageFactory { + /** + * Create error message based on the exception type. + * + * @param e exception to create error message + * @param status exception status code + * @return error message + */ + public static ErrorMessage createErrorMessage(Throwable e, int status) { + Throwable t = e; + int st = status; + if (t instanceof OpenSearchException) { + st = ((OpenSearchException) t).status().getStatus(); + } else { + t = unwrapCause(e); + } + + return new ErrorMessage(t, st); + } + + protected static Throwable unwrapCause(Throwable t) { + Throwable result = t; + if (result instanceof OpenSearchException) { + return result; + } + if (result.getCause() == null) { + return result; + } + result = unwrapCause(result.getCause()); + return result; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java index a89b712170..597ae57a8a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java @@ -7,7 +7,10 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -33,9 +36,11 @@ import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -132,6 +137,69 @@ public void testPrepareRequest() throws Exception { assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); } + public void testPrepareRequest1() throws Exception { + doNothing().when(channel).sendResponse(isA(RestResponse.class)); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null)); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + } + + public void testPrepareRequest2() throws Exception { + doThrow(new IllegalArgumentException("input error")).when(channel).sendResponse(isA(RestResponse.class)); + doNothing().when(channel).sendResponse(isA(BytesRestResponse.class)); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null)); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + } + + public void testPrepareRequestClientError() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new IllegalArgumentException("input error")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + } + + public void testPrepareRequestSystemError() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("system error")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + } + public void testPrepareRequest_disabled() { RestRequest request = getExecuteAgentRestRequest(); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java new file mode 100644 index 0000000000..00f3da1b01 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils.error; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.opensearch.OpenSearchException; +import org.opensearch.core.rest.RestStatus; + +public class ErrorMessageFactoryTests { + + private Throwable nonOpenSearchThrowable = new Throwable(); + private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable); + + @Test + public void openSearchExceptionShouldCreateEsErrorMessage() { + Exception exception = new OpenSearchException(nonOpenSearchThrowable); + ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); + assertTrue(msg.exception instanceof OpenSearchException); + } + + @Test + public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() { + Exception exception = new Exception(nonOpenSearchThrowable); + ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); + assertFalse(msg.exception instanceof OpenSearchException); + } + + @Test + public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() { + Exception exception = (Exception) openSearchThrowable; + ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); + assertTrue(msg.exception instanceof OpenSearchException); + } + + @Test + public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() { + Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable))); + ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); + assertTrue(msg.exception instanceof OpenSearchException); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageTests.java b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageTests.java new file mode 100644 index 0000000000..f4516ad0c1 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageTests.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils.error; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; + +import org.junit.Test; + +public class ErrorMessageTests { + + @Test + public void fetchReason() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + + assertEquals(errorMessage.fetchReason(), "System Error"); + } + + @Test + public void fetchDetails() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + + assertEquals(errorMessage.fetchDetails(), "illegal state"); + } + + @Test + public void testToString() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + assertEquals( + "{\"error\":{\"reason\":\"System Error\",\"details\":\"illegal state\",\"type\":\"IllegalStateException\"},\"status\":503}", + errorMessage.toString() + ); + } + + @Test + public void testBadRequestToString() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException(), BAD_REQUEST.getStatus()); + assertEquals( + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"\",\"type\":\"IllegalStateException\"},\"status\":400}", + errorMessage.toString() + ); + } + + @Test + public void testToStringWithEmptyErrorMessage() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException(), SERVICE_UNAVAILABLE.getStatus()); + assertEquals( + "{\"error\":{\"reason\":\"System Error\",\"details\":\"\",\"type\":\"IllegalStateException\"},\"status\":503}", + errorMessage.toString() + ); + } + + @Test + public void getType() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + + assertEquals(errorMessage.getType(), "IllegalStateException"); + } + + @Test + public void getReason() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + + assertEquals(errorMessage.getReason(), "System Error"); + } + + @Test + public void getDetails() { + ErrorMessage errorMessage = new ErrorMessage(new IllegalStateException("illegal state"), SERVICE_UNAVAILABLE.getStatus()); + + assertEquals(errorMessage.getDetails(), "illegal state"); + } +}