diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 297a592d91..7cfc76ecdc 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -464,7 +464,7 @@ public void registerMLRemoteModel( mlRegisterModelInput.getTenantId(), new MLResourceNotFoundException("Failed to get model group due to index missing") ); - listener.onFailure(e); + listener.onFailure(new OpenSearchStatusException("Model group not found", RestStatus.NOT_FOUND)); } else { log.error("Failed to get model group", e); handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index b807363d4d..7f52fbf636 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -75,6 +75,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; @@ -92,9 +93,11 @@ import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; @@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean()); } + @Test + public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException { + // Create listener and capture the failure + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + ActionListener listener = mock(ActionListener.class); + + // Setup mocks + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); + + // Create test inputs + MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); + MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); + + // Mock index handler + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + + // Mock client.get() to throw IndexNotFoundException + doAnswer(invocation -> { + ActionListener getModelGroupListener = invocation.getArgument(1); + getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test")); + return null; + }).when(client).get(any(), any()); + + // Execute method under test + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener); + + // Verify the listener's onFailure was called with correct exception + verify(listener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + + // Verify exception type and message + assertTrue(exception instanceof OpenSearchStatusException); + assertEquals("Model group not found", exception.getMessage()); + assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status()); + } + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException { ActionListener listener = mock(ActionListener.class); doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());