diff --git a/CHANGELOG.md b/CHANGELOG.md index d10dae25d..905d4a8ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) - Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041)) +- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) - Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040)) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 9b35aa68e..b96acdd9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -167,14 +167,14 @@ private void retryableInferenceSentencesWithMapResult( mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> result = buildMapResultFromResponse(mlOutput); listener.onResponse(result); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + }, + e -> RetryUtil.handleRetryOrFailure( + e, + retryTime, + () -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener), + listener + ) + )); } private void retryableInferenceSentencesWithVectorResult( @@ -188,14 +188,14 @@ private void retryableInferenceSentencesWithVectorResult( mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + }, + e -> RetryUtil.handleRetryOrFailure( + e, + retryTime, + () -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener), + listener + ) + )); } private void retryableInferenceSimilarityWithVectorResult( @@ -209,13 +209,14 @@ private void retryableInferenceSimilarityWithVectorResult( mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); listener.onResponse(scores); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener); - } else { - listener.onFailure(e); - } - })); + }, + e -> RetryUtil.handleRetryOrFailure( + e, + retryTime, + () -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener), + listener + ) + )); } private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { @@ -291,14 +292,20 @@ private void retryableInferenceSentencesWithSingleVectorResult( final List vector = buildSingleVectorFromResponse(mlOutput); log.debug("Inference Response for input sentence is : {} ", vector); listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + }, + e -> RetryUtil.handleRetryOrFailure( + e, + retryTime, + () -> retryableInferenceSentencesWithSingleVectorResult( + targetResponseFilters, + modelId, + inputObjects, + retryTime + 1, + listener + ), + listener + ) + )); } private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { diff --git a/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java index d638fb9c8..c4bd45280 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java @@ -6,30 +6,60 @@ import java.util.List; +import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.exception.ExceptionUtils; +import org.opensearch.core.action.ActionListener; import org.opensearch.transport.NodeDisconnectedException; import org.opensearch.transport.NodeNotConnectedException; import com.google.common.collect.ImmutableList; +import org.opensearch.common.Randomness; +@Log4j2 public class RetryUtil { - private static final int MAX_RETRY = 3; - + private static final int DEFAULT_MAX_RETRY = 3; + private static final long DEFAULT_BASE_DELAY_MS = 500; private static final List> RETRYABLE_EXCEPTIONS = ImmutableList.of( NodeNotConnectedException.class, NodeDisconnectedException.class ); /** - * - * @param e {@link Exception} which is the exception received to check if retryable. - * @param retryTime {@link int} which is the current retried times. - * @return {@link boolean} which is the result of if current exception needs retry or not. + * Handle retry or failure based on the exception and retry time + * @param e Exception + * @param retryTime Retry time + * @param retryAction Action to retry + * @param listener Listener to handle success or failure */ - public static boolean shouldRetry(final Exception e, int retryTime) { - boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1); - return hasRetryException && retryTime < MAX_RETRY; + public static void handleRetryOrFailure(Exception e, int retryTime, Runnable retryAction, ActionListener listener) { + if (shouldRetry(e, retryTime)) { + long backoffTime = calculateBackoffTime(retryTime); + log.warn("Retrying connection for ML inference due to [{}] after [{}ms]", e.getMessage(), backoffTime, e); + try { + Thread.sleep(backoffTime); + } catch (InterruptedException interruptedException) { + Thread.currentThread().interrupt(); + listener.onFailure(interruptedException); + return; + } + retryAction.run(); + } else { + listener.onFailure(e); + } + } + + private static boolean shouldRetry(final Exception e, int retryTime) { + return isRetryableException(e) && retryTime < DEFAULT_MAX_RETRY; } + private static boolean isRetryableException(final Exception e) { + return RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1); + } + + private static long calculateBackoffTime(int retryTime) { + long backoffTime = DEFAULT_BASE_DELAY_MS * (1L << retryTime); // Exponential backoff + long jitter = 10 + (long) (Randomness.get().nextDouble() * (50 - 10)); // Add jitter between 10ms and 50ms + return backoffTime + jitter; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 6d31ea6a6..18129bfec 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -124,7 +124,37 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(resultListener); } - public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() { + public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + // Verify client.predict is called 4 times (1 initial + 3 retries) + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + // Verify failure is propagated to the listener after all retries + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + + // Ensure no additional interactions with the listener + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" @@ -297,18 +327,28 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } - public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { - final RuntimeException exception = new RuntimeException(); + public void testInferenceMultimodal_whenExceptionFromMLClient_thenRetry_thenFailure() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); - actionListener.onFailure(exception); + actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); - Mockito.verify(client) + // Verify client.predict is called 4 times (1 initial + 3 retries) + Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(singleSentenceResultListener).onFailure(exception); + + // Verify failure is propagated to the listener after retries + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + + // Verify no further interactions with the listener Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } @@ -430,29 +470,6 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } - public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() { - final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( - mock(DiscoveryNode.class), - "Node not connected" - ); - Mockito.doAnswer(invocation -> { - final ActionListener actionListener = invocation.getArgument(2); - actionListener.onFailure(nodeNodeConnectedException); - return null; - }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - - accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), - singleSentenceResultListener - ); - - Mockito.verify(client, times(4)) - .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); - } - private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>();