Skip to content

Commit

Permalink
support charset input params and change default charset as utf8
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Nov 23, 2023
1 parent 7cc9399 commit 5022b86
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
try {
String predictEndpoint = connector.getPredictEndpoint(parameters);
request = new HttpPost(predictEndpoint);
HttpEntity entity = new StringEntity(payload);
String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8";
HttpEntity entity = new StringEntity(payload, charset);
((HttpPost) request).setEntity(entity);
} catch (Exception e) {
throw new MLException("Failed to create http request for remote model", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Response;
Expand Down Expand Up @@ -385,7 +384,6 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio
assertTrue((Boolean) responseMap.get("violence"));
}

@Ignore
public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand All @@ -397,9 +395,6 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
+ " \"version\": 1,\n"
+ " \"protocol\": \"http\",\n"
+ " \"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"model\": \"text-embedding-ada-002\"\n"
+ " },\n"
+ " \"credential\": {\n"
Expand All @@ -415,9 +410,9 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
+ " \"headers\": { \n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n"
+ " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n"
+ " \"post_process_function\": \"openai_embedding\"\n"
+ " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\",\n"
+ " \"pre_process_function\": \"connector.pre_process.openai.embedding\",\n"
+ " \"post_process_function\": \"connector.post_process.openai.embedding\"\n"
+ " }\n"
+ " ]\n"
+ "}";
Expand All @@ -435,17 +430,19 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}";
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"input\": [\"This is a string containing Moët Hennessy\"]\n"
+ " }\n"
+ "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("data");
responseMap = (Map) responseList.get(0);
assertFalse(((List) responseMap.get("embedding")).isEmpty());
assertFalse(responseList.isEmpty());
}

public void testCohereGenerateTextModel() throws IOException, InterruptedException {
Expand Down

0 comments on commit 5022b86

Please sign in to comment.