-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change httpclient to async #1958
Changes from all commits
bccfd16
425a450
0600d63
864e4bb
2fe466b
fa8d3ea
b761a5d
02f9602
3ee19ce
136830c
2941daf
6b54018
0bd49aa
5ecf3f2
a29b851
6ff57d9
f0ef57e
caf05a0
cadb989
d8a1281
c7e0588
e59e68c
42d08f8
f100187
e221284
3f0a252
802c537
3c76c22
7df33cd
ffd8565
3199a28
3eb3b7f
80f4c1f
7de1ab4
0ee363a
05a3111
b54c6b1
6a6287e
e3e98a1
8ffe6d7
bebb913
25d6845
031ab19
421b00b
71f3942
20a6172
80b91ed
e446f49
e138d83
f0f3a5b
b3aed63
d12fb86
5e6a7f6
a7a980c
ea2d860
ec8404a
13dc670
05d14bc
005fcd6
0ff85a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,54 +5,43 @@ | |
|
||
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; | ||
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; | ||
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; | ||
import static software.amazon.awssdk.http.SdkHttpMethod.POST; | ||
|
||
import java.io.BufferedReader; | ||
import java.io.InputStreamReader; | ||
import java.net.URI; | ||
import java.nio.charset.StandardCharsets; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.time.Duration; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.opensearch.OpenSearchStatusException; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.util.TokenBucket; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.connector.AwsConnector; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.exception.MLException; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.model.MLGuard; | ||
import org.opensearch.ml.common.output.model.ModelTensors; | ||
import org.opensearch.ml.engine.annotation.ConnectorExecutor; | ||
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; | ||
import org.opensearch.script.ScriptService; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder; | ||
import software.amazon.awssdk.core.sync.RequestBody; | ||
import software.amazon.awssdk.http.AbortableInputStream; | ||
import software.amazon.awssdk.http.HttpExecuteRequest; | ||
import software.amazon.awssdk.http.HttpExecuteResponse; | ||
import software.amazon.awssdk.http.SdkHttpClient; | ||
import software.amazon.awssdk.http.SdkHttpConfigurationOption; | ||
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; | ||
import software.amazon.awssdk.http.SdkHttpFullRequest; | ||
import software.amazon.awssdk.utils.AttributeMap; | ||
import software.amazon.awssdk.http.async.AsyncExecuteRequest; | ||
import software.amazon.awssdk.http.async.SdkAsyncHttpClient; | ||
|
||
@Log4j2 | ||
@ConnectorExecutor(AWS_SIGV4) | ||
public class AwsConnectorExecutor extends AbstractConnectorExecutor { | ||
|
||
@Getter | ||
private AwsConnector connector; | ||
private SdkHttpClient httpClient; | ||
@Setter | ||
@Getter | ||
private ScriptService scriptService; | ||
|
@@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { | |
@Getter | ||
private MLGuard mlGuard; | ||
|
||
public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { | ||
this.connector = (AwsConnector) connector; | ||
this.httpClient = httpClient; | ||
} | ||
private SdkAsyncHttpClient httpClient; | ||
|
||
public AwsConnectorExecutor(Connector connector) { | ||
super.initialize(connector); | ||
this.connector = (AwsConnector) connector; | ||
Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout()); | ||
Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout()); | ||
try ( | ||
AttributeMap attributeMap = AttributeMap | ||
.builder() | ||
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout) | ||
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout) | ||
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections()) | ||
.build() | ||
) { | ||
log | ||
.info( | ||
"Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}", | ||
connectionTimeout, | ||
readTimeout, | ||
super.getConnectorClientConfig().getMaxConnections() | ||
); | ||
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap); | ||
} catch (RuntimeException e) { | ||
log.error("Error initializing AWS connector HTTP client.", e); | ||
throw e; | ||
} catch (Throwable e) { | ||
log.error("Error initializing AWS connector HTTP client.", e); | ||
throw new MLException(e); | ||
} | ||
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); | ||
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); | ||
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); | ||
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); | ||
} | ||
|
||
@SuppressWarnings("removal") | ||
@Override | ||
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) { | ||
public void invokeRemoteModel( | ||
MLInput mlInput, | ||
Map<String, String> parameters, | ||
String payload, | ||
Map<Integer, ModelTensors> tensorOutputs, | ||
ExecutionContext countDownLatch, | ||
ActionListener<List<ModelTensors>> actionListener | ||
) { | ||
try { | ||
String endpoint = connector.getPredictEndpoint(parameters); | ||
RequestBody requestBody = RequestBody.fromString(payload); | ||
|
||
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest | ||
.builder() | ||
.method(POST) | ||
.uri(URI.create(endpoint)) | ||
.contentStreamProvider(requestBody.contentStreamProvider()); | ||
Map<String, String> headers = connector.getDecryptedHeaders(); | ||
if (headers != null) { | ||
for (String key : headers.keySet()) { | ||
builder.putHeader(key, headers.get(key)); | ||
} | ||
} | ||
SdkHttpFullRequest request = builder.build(); | ||
HttpExecuteRequest executeRequest = HttpExecuteRequest | ||
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This hardcodes the method, assuming every remote-model invocation is a "Post" request. Very soon we will add more action types with different methods. So do you want to extend this to support more HTTP methods through different action types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is based on the existing code: https://github.com/opensearch-project/ml-commons/blob/main/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java#L116. I think when we need to support other methods, we can change this at that time; the current implementation can adapt to that change easily. |
||
AsyncExecuteRequest executeRequest = AsyncExecuteRequest | ||
.builder() | ||
.request(signRequest(request)) | ||
.contentStreamProvider(request.contentStreamProvider().orElse(null)) | ||
.requestContentPublisher(new SimpleHttpContentPublisher(request)) | ||
.responseHandler( | ||
new MLSdkAsyncHttpResponseHandler( | ||
countDownLatch, | ||
actionListener, | ||
parameters, | ||
tensorOutputs, | ||
connector, | ||
scriptService, | ||
mlGuard | ||
) | ||
) | ||
.build(); | ||
|
||
HttpExecuteResponse response = AccessController | ||
.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> httpClient.prepareRequest(executeRequest).call()); | ||
int statusCode = response.httpResponse().statusCode(); | ||
|
||
AbortableInputStream body = null; | ||
if (response.responseBody().isPresent()) { | ||
body = response.responseBody().get(); | ||
} | ||
|
||
StringBuilder responseBuilder = new StringBuilder(); | ||
if (body != null) { | ||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) { | ||
String line; | ||
while ((line = reader.readLine()) != null) { | ||
responseBuilder.append(line); | ||
} | ||
} | ||
} else { | ||
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); | ||
} | ||
String modelResponse = responseBuilder.toString(); | ||
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) { | ||
throw new IllegalArgumentException("guardrails triggered for LLM output"); | ||
} | ||
if (statusCode < 200 || statusCode >= 300) { | ||
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); | ||
} | ||
|
||
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); | ||
tensors.setStatusCode(statusCode); | ||
tensorOutputs.add(tensors); | ||
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest)); | ||
} catch (RuntimeException exception) { | ||
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); | ||
throw exception; | ||
actionListener.onFailure(exception); | ||
} catch (Throwable e) { | ||
log.error("Failed to execute predict in aws connector", e); | ||
throw new MLException("Fail to execute predict in aws connector", e); | ||
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; | ||
|
||
import java.io.IOException; | ||
import java.net.URI; | ||
import java.nio.charset.Charset; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
|
@@ -34,6 +36,7 @@ | |
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.model.MLGuard; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
import org.opensearch.ml.common.output.model.ModelTensors; | ||
import org.opensearch.script.ScriptService; | ||
|
@@ -46,7 +49,9 @@ | |
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; | ||
import software.amazon.awssdk.auth.signer.Aws4Signer; | ||
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams; | ||
import software.amazon.awssdk.core.sync.RequestBody; | ||
import software.amazon.awssdk.http.SdkHttpFullRequest; | ||
import software.amazon.awssdk.http.SdkHttpMethod; | ||
import software.amazon.awssdk.regions.Region; | ||
|
||
@Log4j2 | ||
|
@@ -179,11 +184,15 @@ public static ModelTensors processOutput( | |
String modelResponse, | ||
Connector connector, | ||
ScriptService scriptService, | ||
Map<String, String> parameters | ||
Map<String, String> parameters, | ||
MLGuard mlGuard | ||
) throws IOException { | ||
if (modelResponse == null) { | ||
throw new IllegalArgumentException("model response is null"); | ||
} | ||
if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) { | ||
throw new IllegalArgumentException("guardrails triggered for LLM output"); | ||
} | ||
List<ModelTensor> modelTensors = new ArrayList<>(); | ||
Optional<ConnectorAction> predictAction = connector.findPredictAction(); | ||
if (predictAction.isEmpty()) { | ||
|
@@ -252,4 +261,42 @@ public static SdkHttpFullRequest signRequest( | |
|
||
return signer.sign(request, params); | ||
} | ||
|
||
public static SdkHttpFullRequest buildSdkRequest( | ||
Connector connector, | ||
Map<String, String> parameters, | ||
String payload, | ||
SdkHttpMethod method | ||
) { | ||
String charset = parameters.getOrDefault("charset", "UTF-8"); | ||
RequestBody requestBody; | ||
if (payload != null) { | ||
requestBody = RequestBody.fromString(payload, Charset.forName(charset)); | ||
} else { | ||
requestBody = RequestBody.empty(); | ||
} | ||
if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any chance of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, even when payload is null, we create requestBody with |
||
log.error("Content length is 0. Aborting request to remote model"); | ||
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); | ||
} | ||
String endpoint = connector.getPredictEndpoint(parameters); | ||
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest | ||
.builder() | ||
.method(method) | ||
.uri(URI.create(endpoint)) | ||
.contentStreamProvider(requestBody.contentStreamProvider()); | ||
Map<String, String> headers = connector.getDecryptedHeaders(); | ||
if (headers != null) { | ||
for (String key : headers.keySet()) { | ||
builder.putHeader(key, headers.get(key)); | ||
} | ||
} | ||
if (builder.matchingHeaders("Content-Type").isEmpty()) { | ||
builder.putHeader("Content-Type", "application/json"); | ||
} | ||
if (builder.matchingHeaders("Content-Length").isEmpty()) { | ||
builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this mandatory? We didn't have this part before. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently no LLM accepts an empty request body; adding this is an enhancement that we should have had in the first place. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quite understand: if you need to check the request body, you can check the body length directly, so why add a header? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Async httpclient doesn't set the |
||
} | ||
return builder.build(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* | ||
* * Copyright OpenSearch Contributors | ||
* * SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
import java.util.concurrent.CountDownLatch; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Data; | ||
|
||
/** | ||
* This class encapsulates several parameters that are used in a split-batch request case. | ||
* A batch request is one where the neural-search side sends multiple fields to ml-commons in a single request, | ||
* but the remote model doesn't accept a list of string inputs, so ml-commons needs to split the request. | ||
* sequence is used to identify the index of the split request. | ||
* countDownLatch is used to wait for all the split requests to finish. | ||
* exceptionHolder is used to hold any exception thrown in a split-batch request. | ||
*/ | ||
@Data | ||
@AllArgsConstructor | ||
public class ExecutionContext { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add some comment about what is this class for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure |
||
// Should never be null | ||
private int sequence; | ||
private CountDownLatch countDownLatch; | ||
// This is to hold any exception thrown in a split-batch request | ||
private AtomicReference<Exception> exceptionHolder; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Codecov shows this default exception is not covered. Also, why is it necessary to throw an exception by default here? If the method is not implemented, wouldn't compilation fail first?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added tests for this. No, this is a runtime error, compilation won't be impacted by this. Below are reasons:
First of all, this is an interface method, but not all subclasses need to implement it, e.g. RemoteModel. So we need to make it a default method, which can then be accessed in all of its implementing classes. Also, this method needs a return value; there are three options:
In fact, there's another option: keep this method abstract, so that all subclasses need to implement it and, depending on their case, can throw UnsupportedOperationException, which also avoids accidental misuse of the method. But with that option we would need to change quite a lot of files, and since the default method also works, I chose this simpler option.