Skip to content

Commit

Permalink
add more it
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 4, 2023
1 parent 5022b86 commit 9fa892c
Showing 1 changed file with 30 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
Expand Down Expand Up @@ -384,7 +385,30 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio
assertTrue((Boolean) responseMap.get("violence"));
}

public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException {
public void testOpenAITextEmbeddingModel_UTF8() throws IOException, InterruptedException {
testOpenAITextEmbeddingModel("UTF-8", (responseMap) -> {
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("data");
assertFalse(responseList.isEmpty());
return null;
});
}

public void testOpenAITextEmbeddingModel_ISO8859_1() throws IOException, InterruptedException {
testOpenAITextEmbeddingModel("ISO-8859-1", (responseMap) -> {
assertFalse(responseMap.containsKey("inference_results"));
Map error = (Map) responseMap.get("error");
String reason = (String) error.get("reason");
assertTrue(reason.contains("'utf-8' codec can't decode byte 0xeb"));
return null;
});
}

private void testOpenAITextEmbeddingModel(String charset, Function<Map, Void> verificationFunction) throws IOException,
InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
Expand Down Expand Up @@ -432,17 +456,15 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"input\": [\"This is a string containing Moët Hennessy\"]\n"
+ " \"input\": [\"This is a string containing Moët Hennessy\"],\n"
+ " \"charset\": \""
+ charset
+ "\"\n"
+ " }\n"
+ "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("data");
assertFalse(responseList.isEmpty());
verificationFunction.apply(responseMap);
}

public void testCohereGenerateTextModel() throws IOException, InterruptedException {
Expand Down

0 comments on commit 9fa892c

Please sign in to comment.