Skip to content

Commit

Permalink
gracefully handles model group index not found exception
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Feb 1, 2025
1 parent d7dec0f commit 44f26a7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ public void registerMLRemoteModel(
mlRegisterModelInput.getTenantId(),
new MLResourceNotFoundException("Failed to get model group due to index missing")
);
listener.onFailure(e);
listener.onFailure(new OpenSearchStatusException("Model group not found", RestStatus.NOT_FOUND));
} else {
log.error("Failed to get model group", e);
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
Expand All @@ -92,9 +93,11 @@
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.get.GetResult;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
Expand Down Expand Up @@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
}

@Test
public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException {
// Create listener and capture the failure
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);

// Setup mocks
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);

// Create test inputs
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();

// Mock index handler
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);

// Mock client.get() to throw IndexNotFoundException
doAnswer(invocation -> {
ActionListener<GetResponse> getModelGroupListener = invocation.getArgument(1);
getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test"));
return null;
}).when(client).get(any(), any());

// Execute method under test
modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener);

// Verify the listener's onFailure was called with correct exception
verify(listener).onFailure(exceptionCaptor.capture());
Exception exception = exceptionCaptor.getValue();

// Verify exception type and message
assertTrue(exception instanceof OpenSearchStatusException);
assertEquals("Model group not found", exception.getMessage());
assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status());
}

public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down

0 comments on commit 44f26a7

Please sign in to comment.