diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java index a1ccff88f2..21217b09a0 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java @@ -154,7 +154,15 @@ public void onResponse(SearchResponse searchResponse) { @Override public void onFailure(Exception e) { - onFailed(channel, "Searching model wasn't successful", e); + try { + builder.startObject(); + builder.endObject(); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + } catch (IOException ex) { + String errorMessage = "Failed to get ML node level profile"; + log.error(errorMessage, e); + onFailed(channel, errorMessage, e); + } } }, threadContext::restore)); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 6560362c48..105a7ff9a0 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -176,7 +176,11 @@ public void onResponse(SearchResponse searchResponse) { @Override public void onFailure(Exception e) { - onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Searching model wasn't successful", e); + try { + getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel); + } catch (IOException ex) { + onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to retrieve Cluster level metrics", e); + } } }, threadContext::restore)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java index d205f5cbd8..446dc74213 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; @@ -51,6 +52,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.action.profile.MLProfileAction; @@ -67,6 +69,7 @@ import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.profile.MLPredictRequestStats; import org.opensearch.ml.profile.MLProfileInput; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -151,13 +154,6 @@ public void setup() throws IOException { testState = setupTestClusterState(); when(clusterService.state()).thenReturn(testState); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here - listener.onResponse(response); - return null; - }).when(client).search(any(SearchRequest.class), any()); - doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); Map nodeTasks = new HashMap<>(); @@ -207,6 +203,13 @@ public void testRoutes() { } public void test_PrepareRequest_TaskRequest() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + RestRequest request = getRestRequest(); profileAction.handleRequest(request, channel, client); @@ -218,6 +221,13 @@ public void test_PrepareRequest_TaskRequest() throws Exception { } public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/tasks").build(); profileAction.handleRequest(request, channel, client); @@ -228,6 +238,13 @@ public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception { } public void test_PrepareRequest_ModelRequest() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + RestRequest request = getModelRestRequest(); profileAction.handleRequest(request, channel, client); @@ -239,6 +256,13 @@ public void test_PrepareRequest_ModelRequest() throws Exception { } public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/models").build(); profileAction.handleRequest(request, channel, client); @@ -249,6 +273,12 @@ public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception { } public void test_PrepareRequest_EmptyNodeProfile() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); MLProfileResponse profileResponse = new MLProfileResponse(clusterName, new ArrayList<>(), new ArrayList<>()); @@ -267,6 +297,13 @@ public void test_PrepareRequest_EmptyNodeProfile() throws Exception { } public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); Map nodeTasks = new HashMap<>(); @@ -288,6 +325,13 @@ public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception { } public void test_PrepareRequest_WithRequestContent() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + MLProfileInput mlProfileInput = new MLProfileInput(); RestRequest request = getProfileRestRequest(mlProfileInput); profileAction.handleRequest(request, channel, client); @@ -296,6 +340,13 @@ public void test_PrepareRequest_WithRequestContent() throws Exception { } public void test_PrepareRequest_Failure() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new RuntimeException("test failure")); @@ -308,7 +359,40 @@ public void test_PrepareRequest_Failure() throws Exception { verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any()); } + public void test_Search_Failure() throws Exception { + // Setup to simulate a search failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("Mocking Exception")); // Trigger failure + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + // Create a RestRequest instance for testing + RestRequest request = getRestRequest(); // Ensure this method correctly initializes a RestRequest + + // Handle the request with the expectation of handling a failure + profileAction.handleRequest(request, channel, client); + + // Verification that the search method was called exactly once + verify(client, times(1)).search(any(SearchRequest.class), any(ActionListener.class)); + + // Capturing the response sent to the channel + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); + verify(channel).sendResponse(responseCaptor.capture()); + + // Check the response status code to see if it correctly reflects the error + BytesRestResponse response = responseCaptor.getValue(); + assertEquals(RestStatus.OK, response.status()); + assertTrue(response.content().utf8ToString().contains("{}")); + } + public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); MLProfileInput mlProfileInput = new MLProfileInput(); RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model")); profileAction.handleRequest(request, channel, client); @@ -316,6 +400,43 @@ public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception { verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any()); } + // public void testNodeViewOutput() throws Exception { + // // Assuming setup for non-empty node responses as done in the initial setup + // MLProfileInput mlProfileInput = new MLProfileInput(); + // RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "node")); + // profileAction.handleRequest(request, channel, client); + // + // ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class); + // verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any()); + // + // // Verify that the response is correctly formed for the node view + // verify(channel).sendResponse(argThat(response -> { + // // Ensure the response content matches expected node view structure + // String content = response.content().utf8ToString(); + // return content.contains("\"node\":") && !content.contains("\"models\":"); + // })); + // } + + public void testBackendFailureHandling() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here + listener.onResponse(response); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Simulated backend failure")); + return null; + }).when(client).execute(eq(MLProfileAction.INSTANCE), any(MLProfileRequest.class), any(ActionListener.class)); + + RestRequest request = getRestRequest(); + profileAction.handleRequest(request, channel, client); + + verify(channel).sendResponse(argThat(response -> response.status() == RestStatus.INTERNAL_SERVER_ERROR)); + } + private SearchResponse createSearchModelResponse() throws IOException { XContentBuilder content = builder(); content.startObject();