Change httpclient to async #1958

Merged
merged 60 commits into from
Apr 30, 2024
Changes from 54 commits
bccfd16
Change httpclient from sync to async
zane-neo Jan 9, 2024
425a450
Change from CRTAsyncHttpClient to NettyAsyncHttpClient
zane-neo Jan 9, 2024
0600d63
Add publisher to request
zane-neo Jan 10, 2024
864e4bb
Change sync httpclient to async
zane-neo Jan 15, 2024
2fe466b
Handle error case and return error response in actionListener
zane-neo Jan 23, 2024
fa8d3ea
Fix no response when exception
zane-neo Jan 24, 2024
b761a5d
Add content type header
zane-neo Jan 24, 2024
02f9602
Fix issues found in functional test
zane-neo Jan 25, 2024
3ee19ce
Fix no response issue in functional test
zane-neo Jan 26, 2024
136830c
fix default step size error
zane-neo Jan 26, 2024
2941daf
Add track inference duration for async httpclient
zane-neo Jan 26, 2024
6b54018
Change client appsec highlight issues implementation for async httpcl…
zane-neo Jan 26, 2024
0bd49aa
Add UTs
zane-neo Jan 29, 2024
5ecf3f2
Add UTs
zane-neo Jan 30, 2024
a29b851
Remove unused file
zane-neo Jan 30, 2024
6ff57d9
Add UTs
zane-neo Jan 30, 2024
f0ef57e
format code
zane-neo Jan 30, 2024
caf05a0
Change error code to honor remote service error code
zane-neo Jan 31, 2024
cadb989
Add more UTs
zane-neo Feb 4, 2024
d8a1281
Change SSRF code to make it correct for return error status
zane-neo Feb 4, 2024
c7e0588
Fix failure UTs and add more UTs
zane-neo Feb 4, 2024
e59e68c
Fix failure ITs
zane-neo Feb 4, 2024
42d08f8
format code
zane-neo Feb 4, 2024
f100187
Fix partial success response not correct issue
zane-neo Feb 5, 2024
e221284
format code
zane-neo Feb 5, 2024
3f0a252
Fix failure ITs
zane-neo Feb 5, 2024
802c537
Add more UTs to increase code coverage
zane-neo Feb 5, 2024
3c76c22
Change url regex
zane-neo Feb 20, 2024
7df33cd
Address comments
zane-neo Feb 20, 2024
ffd8565
format code
zane-neo Feb 21, 2024
3199a28
Fix failure UTs
zane-neo Feb 21, 2024
3eb3b7f
Add UT for httpclientFactory throw exception when creating httpclient
zane-neo Feb 21, 2024
80f4c1f
format code
zane-neo Feb 21, 2024
7de1ab4
Address comments and add modelTensor status code
zane-neo Feb 23, 2024
0ee363a
Address comments
zane-neo Feb 26, 2024
05a3111
format code
zane-neo Feb 28, 2024
b54c6b1
Add status code to process error response
zane-neo Feb 29, 2024
6a6287e
format code
zane-neo Feb 29, 2024
e3e98a1
Rebase main after connector level http parameter support
zane-neo Mar 7, 2024
8ffe6d7
Fix UT
zane-neo Mar 7, 2024
bebb913
Change error message when remote model return empty and change the b…
zane-neo Mar 9, 2024
25d6845
Add comments
zane-neo Mar 9, 2024
031ab19
Remove redundant builder and change the error code check
zane-neo Mar 13, 2024
421b00b
format code
zane-neo Mar 13, 2024
71f3942
Add more UTs for throw exception cases
zane-neo Mar 13, 2024
20a6172
fix failure UTs
zane-neo Mar 13, 2024
80b91ed
format code
zane-neo Mar 13, 2024
e446f49
Fix test cases since the error message change
zane-neo Apr 11, 2024
e138d83
Rebase code
zane-neo Apr 11, 2024
f0f3a5b
fix failure IT
zane-neo Apr 12, 2024
b3aed63
Add more UTs
zane-neo Apr 12, 2024
d12fb86
Fix duplicate response to client issue
zane-neo Apr 26, 2024
5e6a7f6
fix duplicate response in channel
zane-neo Apr 26, 2024
a7a980c
change code for all successfully responses case
zane-neo Apr 26, 2024
ea2d860
Address comments
zane-neo Apr 27, 2024
ec8404a
format code
zane-neo Apr 27, 2024
13dc670
Increase nio httpclient version to fix vulnerability
zane-neo Apr 28, 2024
05d14bc
Change validate localhost logic to same with existing code
zane-neo Apr 29, 2024
005fcd6
change method signature to private
zane-neo Apr 29, 2024
0ff85a8
format code
zane-neo Apr 29, 2024
3 changes: 2 additions & 1 deletion ml-algorithms/build.gradle
@@ -63,12 +63,13 @@ dependencies {
}
}

implementation platform('software.amazon.awssdk:bom:2.21.15')
implementation platform('software.amazon.awssdk:bom:2.22.12')
ylwu-amzn marked this conversation as resolved.
implementation 'software.amazon.awssdk:auth'
implementation 'software.amazon.awssdk:apache-client'
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.22.12'
ylwu-amzn marked this conversation as resolved.
}

lombok {
@@ -7,9 +7,11 @@

import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;

/**
@@ -31,7 +33,13 @@ public interface Predictable {
* @param mlInput input data
* @return predicted results
*/
MLOutput predict(MLInput mlInput);
default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
Collaborator

Codecov shows that this default exception path is not covered. Also, why default to throwing an exception here? If the method is not implemented, wouldn't compilation fail first?

Collaborator Author

Added tests for this. No, this is a runtime error; compilation is not affected. The reasons are below.
First, this is an interface method, but not all implementing classes need to implement it, e.g. RemoteModel. So it has to be a default method, which makes it accessible from every implementing class. The method also needs a return value, and there are three options:

  1. Return null: not good, since it could take someone a long time to identify the root cause.
  2. Return a manually created object: not good, since we don't want to fabricate an object.
  3. Throw an exception: good, since the exception and its error message are clear and searchable, so a developer or user can identify the issue quickly.

There is in fact another option: keep the method abstract and have every implementing class implement it, throwing UnsupportedOperationException where appropriate, which also prevents accidental misuse of the method. But that option requires changing quite a lot of files, and since the default method also works, I chose this simpler option.
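The trade-off described in this reply can be sketched in isolation. The block below is a minimal illustration, not the PR's code: `Predictor`, `LocalPredictor`, and `RemotePredictor` are hypothetical stand-ins for `Predictable` and its implementations, and plain `Consumer` callbacks stand in for `ActionListener`. It shows why the default methods compile for every implementation yet fail loudly at runtime when an unsupported path is called.

```java
import java.util.function.Consumer;

// Hypothetical stand-in for the Predictable interface; names are illustrative only.
interface Predictor {
    // Synchronous path: implementations that do not override it inherit a loud runtime failure.
    default String predict(String input) {
        throw new IllegalStateException("Method is not implemented");
    }

    // Asynchronous path: the failure is reported through the callback instead of being thrown.
    default void asyncPredict(String input, Consumer<String> onResponse, Consumer<Exception> onFailure) {
        onFailure.accept(new IllegalStateException("Method is not implemented"));
    }
}

// A local model overrides only the synchronous method.
class LocalPredictor implements Predictor {
    @Override
    public String predict(String input) {
        return "local:" + input;
    }
}

// A remote model overrides only the asynchronous method.
class RemotePredictor implements Predictor {
    @Override
    public void asyncPredict(String input, Consumer<String> onResponse, Consumer<Exception> onFailure) {
        onResponse.accept("remote:" + input);
    }
}
```

Both classes compile without implementing the other method, which is the point of the default: calling `predict` on `RemotePredictor` throws `IllegalStateException` at runtime, and calling `asyncPredict` on `LocalPredictor` routes the same exception to the failure callback.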

}

default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
}

/**
* Init model (load model into memory) with ML model content and params.
@@ -5,54 +5,43 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;

@Log4j2
@ConnectorExecutor(AWS_SIGV4)
public class AwsConnectorExecutor extends AbstractConnectorExecutor {

@Getter
private AwsConnector connector;
private SdkHttpClient httpClient;
@Setter
@Getter
private ScriptService scriptService;
@@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
@Getter
private MLGuard mlGuard;

public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
this.connector = (AwsConnector) connector;
this.httpClient = httpClient;
}
private SdkAsyncHttpClient httpClient;

public AwsConnectorExecutor(Connector connector) {
super.initialize(connector);
this.connector = (AwsConnector) connector;
Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout());
try (
AttributeMap attributeMap = AttributeMap
.builder()
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections())
.build()
) {
log
.info(
"Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}",
connectionTimeout,
readTimeout,
super.getConnectorClientConfig().getMaxConnections()
);
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
} catch (RuntimeException e) {
log.error("Error initializing AWS connector HTTP client.", e);
throw e;
} catch (Throwable e) {
log.error("Error initializing AWS connector HTTP client.", e);
throw new MLException(e);
}
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
}

@SuppressWarnings("removal")
@Override
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
public void invokeRemoteModel(
MLInput mlInput,
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
String endpoint = connector.getPredictEndpoint(parameters);
RequestBody requestBody = RequestBody.fromString(payload);

SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(POST)
.uri(URI.create(endpoint))
.contentStreamProvider(requestBody.contentStreamProvider());
Map<String, String> headers = connector.getDecryptedHeaders();
if (headers != null) {
for (String key : headers.keySet()) {
builder.putHeader(key, headers.get(key));
}
}
SdkHttpFullRequest request = builder.build();
HttpExecuteRequest executeRequest = HttpExecuteRequest
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
Collaborator

This hardcodes POST, which assumes every remote model invocation is a POST request. Very soon we will add more action types with different methods. Do you want to extend this to support more HTTP methods through different action types?

Collaborator Author

This change is based on the existing code: https://github.com/opensearch-project/ml-commons/blob/main/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java#L116. I think we can change this when we need to support other methods; the current implementation can adapt to that change easily.
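One way the follow-up discussed in this thread might look is sketched below. It is purely illustrative and not part of the PR: `MethodResolver` and its `HttpMethod` enum are hypothetical names, and the idea of reading the method from a configured string is an assumption. The sketch keeps POST as the fallback so existing behavior would be unchanged when no method is configured.

```java
import java.util.Locale;

// Hypothetical helper; the PR itself passes POST directly to buildSdkRequest.
class MethodResolver {
    enum HttpMethod { GET, POST, PUT, DELETE }

    // Falls back to POST, the method the PR currently hardcodes, when nothing is configured.
    static HttpMethod resolve(String configured) {
        if (configured == null || configured.trim().isEmpty()) {
            return HttpMethod.POST;
        }
        try {
            return HttpMethod.valueOf(configured.trim().toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Surface a clear message instead of the raw enum lookup failure.
            throw new IllegalArgumentException("Unsupported HTTP method: " + configured, e);
        }
    }
}
```

Because `buildSdkRequest` already takes the method as a parameter, a resolver like this could feed it without changing the request-building code, which matches the author's claim that the current shape adapts easily.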

AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
.contentStreamProvider(request.contentStreamProvider().orElse(null))
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
mlGuard
)
)
.build();

HttpExecuteResponse response = AccessController
.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> httpClient.prepareRequest(executeRequest).call());
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
body = response.responseBody().get();
}

StringBuilder responseBuilder = new StringBuilder();
if (body != null) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
responseBuilder.append(line);
}
}
} else {
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
}
String modelResponse = responseBuilder.toString();
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
if (statusCode < 200 || statusCode >= 300) {
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
throw exception;
actionListener.onFailure(exception);
} catch (Throwable e) {
log.error("Failed to execute predict in aws connector", e);
throw new MLException("Fail to execute predict in aws connector", e);
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
}
}

@@ -15,6 +15,9 @@
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -34,8 +37,10 @@
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import com.jayway.jsonpath.JsonPath;
@@ -46,7 +51,9 @@
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

@Log4j2
@@ -179,11 +186,15 @@ public static ModelTensors processOutput(
String modelResponse,
Connector connector,
ScriptService scriptService,
Map<String, String> parameters
Map<String, String> parameters,
MLGuard mlGuard
) throws IOException {
if (modelResponse == null) {
throw new IllegalArgumentException("model response is null");
}
if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) {
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
List<ModelTensor> modelTensors = new ArrayList<>();
Optional<ConnectorAction> predictAction = connector.findPredictAction();
if (predictAction.isEmpty()) {
@@ -252,4 +263,47 @@ public static SdkHttpFullRequest signRequest(

return signer.sign(request, params);
}

public static SdkHttpFullRequest buildSdkRequest(
Connector connector,
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) throws Exception {
String endpoint = connector.getPredictEndpoint(parameters);
URL url = new URL(endpoint);
String protocol = url.getProtocol();
String host = url.getHost();
int port = url.getPort();
MLHttpClientFactory.validate(protocol, host, port);
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
if (payload != null) {
requestBody = RequestBody.fromString(payload, Charset.forName(charset));
} else {
requestBody = RequestBody.empty();
}
if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) {
Collaborator

Is there any chance that optionalContentLength is null? If so, could this check be skipped?

Collaborator Author

No. Even when payload is null, we create requestBody with RequestBody.empty(), which is not null. I use this method on purpose to make sure requestBody is never null.
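A related point is that the value in question is an `Optional`, so the failure mode to guard against is an empty `Optional` rather than a null reference: calling `get()` on an empty `Optional` throws `NoSuchElementException`. The sketch below is a defensive variant of the empty-body check using only `java.util.Optional`; `ContentLengthCheck` is a hypothetical name, and treating an absent length as 0 is an assumption of this sketch, not something the PR does.

```java
import java.util.Optional;

// Hypothetical helper illustrating a null-safe and empty-safe form of the check.
class ContentLengthCheck {
    // Returns true when a POST request would be sent with an empty body.
    // An absent length is treated as 0 here; that choice is an assumption of this sketch.
    static boolean isEmptyPost(String method, Optional<Long> contentLength) {
        return "POST".equals(method) && contentLength.orElse(0L) == 0L;
    }
}
```

With `orElse(0L)` the check cannot be skipped by a missing length and cannot throw, at the cost of rejecting a POST whose length is merely unknown.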

log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
.uri(URI.create(endpoint))
.contentStreamProvider(requestBody.contentStreamProvider());
Map<String, String> headers = connector.getDecryptedHeaders();
if (headers != null) {
for (String key : headers.keySet()) {
builder.putHeader(key, headers.get(key));
}
}
if (builder.matchingHeaders("Content-Type").isEmpty()) {
builder.putHeader("Content-Type", "application/json");
}
if (builder.matchingHeaders("Content-Length").isEmpty()) {
builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString());
Collaborator

Is this mandatory? We didn't have this part before.

Collaborator Author

Currently no LLM accepts an empty request body. Adding this is an enhancement to our code; we should have had it in the first place.

Collaborator

I don't quite understand. If you need to check the request body, you can check the body length directly. Why do you have to add a header?

Collaborator Author

The async httpclient doesn't set Content-Length automatically, so we need to set it manually; otherwise we can't get a correct response from the remote model.
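The header defaulting described in this thread can be illustrated without the AWS SDK. The sketch below is an assumption-laden stand-in: `HeaderDefaults` is a hypothetical helper that works on a plain `Map`, and it computes the length from the encoded payload bytes, whereas the PR reads it from `requestBody.optionalContentLength()`. It shows the two behaviors the diff relies on: caller-supplied headers win, and Content-Length counts bytes in the chosen charset rather than Java characters.

```java
import java.nio.charset.Charset;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper; the PR applies the same idea to SdkHttpFullRequest.Builder.
class HeaderDefaults {
    // Adds Content-Type and Content-Length only when the caller has not supplied them.
    // Header-name matching is case-sensitive here for brevity; real HTTP header names are case-insensitive.
    static Map<String, String> withDefaults(Map<String, String> headers, String payload, Charset charset) {
        Map<String, String> result = new LinkedHashMap<>(headers);
        result.putIfAbsent("Content-Type", "application/json");
        // Content-Length counts encoded bytes, not Java chars.
        int byteLength = payload == null ? 0 : payload.getBytes(charset).length;
        result.putIfAbsent("Content-Length", String.valueOf(byteLength));
        return result;
    }
}
```

The byte-versus-character distinction matters for the `charset` parameter that `buildSdkRequest` accepts: a payload containing non-ASCII text has a different length in UTF-8 than in a single-byte charset.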

}
return builder.build();
}
}
@@ -0,0 +1,24 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.engine.algorithms.remote;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class ExecutionContext {
Collaborator

Let's add a comment about what this class is for.

Collaborator Author

sure

// Should never be null
private int sequence;
private CountDownLatch countDownLatch;
// This is to hold any exception thrown in a split-batch request
private AtomicReference<Exception> exceptionHolder;
}
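To make the purpose of these three fields concrete, here is a minimal sketch of how a split-batch request might use them. It is an illustration under stated assumptions, not the PR's handler: `SplitBatchSketch` and `fakeRemoteCall` are hypothetical, `CompletableFuture` stands in for both the async HTTP client and the `ActionListener`, and the sketch assumes the sequence number keys each partial result so the output can be reassembled in input order (consistent with `tensorOutputs` becoming a `Map<Integer, ModelTensors>` in this PR).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

class SplitBatchSketch {
    // Sends each chunk asynchronously and completes the returned future once every chunk has answered.
    static CompletableFuture<List<String>> run(List<String> chunks) {
        CompletableFuture<List<String>> response = new CompletableFuture<>();
        if (chunks.isEmpty()) {
            response.complete(Collections.<String>emptyList());
            return response;
        }
        Map<Integer, String> outputs = new ConcurrentHashMap<>();
        CountDownLatch latch = new CountDownLatch(chunks.size());
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        for (int sequence = 0; sequence < chunks.size(); sequence++) {
            final int seq = sequence;
            CompletableFuture.supplyAsync(() -> fakeRemoteCall(chunks.get(seq))).whenComplete((result, error) -> {
                if (error != null) {
                    // Keep only the first failure seen across all chunks.
                    exceptionHolder.compareAndSet(null, new RuntimeException(error));
                } else {
                    outputs.put(seq, result);
                }
                latch.countDown();
                if (latch.getCount() == 0) {
                    // More than one callback may observe zero; completing a future twice is a no-op.
                    if (exceptionHolder.get() != null) {
                        response.completeExceptionally(exceptionHolder.get());
                    } else {
                        List<String> ordered = new ArrayList<>();
                        for (int i = 0; i < chunks.size(); i++) {
                            ordered.add(outputs.get(i));
                        }
                        response.complete(ordered);
                    }
                }
            });
        }
        return response;
    }

    // Hypothetical stand-in for the remote model call.
    static String fakeRemoteCall(String chunk) {
        if (chunk.isEmpty()) {
            throw new IllegalArgumentException("Content length is 0");
        }
        return chunk.toUpperCase();
    }
}
```

The comment about more than one callback observing a zero count is worth noting: with a real listener that is not idempotent, the same race would send the response twice, so the final step needs a guard of some kind.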