Skip to content

Commit

Permalink
Optimize ML inference connection retry logic
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jan 4, 2025
1 parent c27fa94 commit 6e9e07d
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 69 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ private void retryableInferenceSentencesWithMapResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSentencesWithVectorResult(
Expand All @@ -188,14 +188,14 @@ private void retryableInferenceSentencesWithVectorResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSimilarityWithVectorResult(
Expand All @@ -209,13 +209,14 @@ private void retryableInferenceSimilarityWithVectorResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
listener
)
));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
Expand Down Expand Up @@ -291,14 +292,20 @@ private void retryableInferenceSentencesWithSingleVectorResult(
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithSingleVectorResult(
targetResponseFilters,
modelId,
inputObjects,
retryTime + 1,
listener
),
listener
)
));
}

private MLInput createMLMultimodalInput(final List<String> targetResponseFilters, final Map<String, String> input) {
Expand Down
48 changes: 39 additions & 9 deletions src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,60 @@

import java.util.List;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.transport.NodeDisconnectedException;
import org.opensearch.transport.NodeNotConnectedException;

import com.google.common.collect.ImmutableList;
import org.opensearch.common.Randomness;

@Log4j2
public class RetryUtil {

private static final int MAX_RETRY = 3;

private static final int DEFAULT_MAX_RETRY = 3;
private static final long DEFAULT_BASE_DELAY_MS = 500;
private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = ImmutableList.of(
NodeNotConnectedException.class,
NodeDisconnectedException.class
);

/**
*
* @param e {@link Exception} which is the exception received to check if retryable.
* @param retryTime {@link int} which is the current retried times.
* @return {@link boolean} which is the result of if current exception needs retry or not.
* Handle retry or failure based on the exception and retry time
* @param e Exception
* @param retryTime Retry time
* @param retryAction Action to retry
* @param listener Listener to handle success or failure
*/
public static boolean shouldRetry(final Exception e, int retryTime) {
boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1);
return hasRetryException && retryTime < MAX_RETRY;
public static void handleRetryOrFailure(Exception e, int retryTime, Runnable retryAction, ActionListener<?> listener) {
if (shouldRetry(e, retryTime)) {
long backoffTime = calculateBackoffTime(retryTime);
log.warn("Retrying connection for ML inference due to [{}] after [{}ms]", e.getMessage(), backoffTime, e);
try {
Thread.sleep(backoffTime);
} catch (InterruptedException interruptedException) {
Thread.currentThread().interrupt();
listener.onFailure(interruptedException);
return;
}
retryAction.run();
} else {
listener.onFailure(e);
}
}

private static boolean shouldRetry(final Exception e, int retryTime) {
return isRetryableException(e) && retryTime < DEFAULT_MAX_RETRY;
}

private static boolean isRetryableException(final Exception e) {
return RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1);
}

private static long calculateBackoffTime(int retryTime) {
long backoffTime = DEFAULT_BASE_DELAY_MS * (1L << retryTime); // Exponential backoff
long jitter = 10 + (long) (Randomness.get().nextDouble() * (50 - 10)); // Add jitter between 10ms and 50ms
return backoffTime + jitter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,37 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() {
public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

// Verify failure is propagated to the listener after all retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);

// Ensure no additional interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
Expand Down Expand Up @@ -297,18 +327,28 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() {
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() {
final RuntimeException exception = new RuntimeException();
public void testInferenceMultimodal_whenExceptionFromMLClient_thenRetry_thenFailure() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(exception);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);

// Verify failure is propagated to the listener after retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);

// Verify no further interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

Expand Down Expand Up @@ -430,29 +470,6 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down

0 comments on commit 6e9e07d

Please sign in to comment.