Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test on personal repo #3

Closed
wants to merge 13 commits into from
31 changes: 18 additions & 13 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
java: [11, 17, 21]

name: Build and Test MLCommons Plugin on linux
if: github.repository == 'opensearch-project/ml-commons'
#if: github.repository == 'opensearch-project/ml-commons'
environment: ml-commons-cicd-env
outputs:
build-test-linux: ${{ steps.step-build-test-linux.outputs.build-test-linux }}
Expand All @@ -44,10 +44,10 @@ jobs:
with:
java-version: ${{ matrix.java }}

- uses: aws-actions/configure-aws-credentials@v2
with:
role-to-assume: ${{ secrets.ML_ROLE }}
aws-region: us-west-2
# - uses: aws-actions/configure-aws-credentials@v2
# with:
# role-to-assume: ${{ secrets.ML_ROLE }}
# aws-region: us-west-2

- name: Checkout MLCommons
uses: actions/checkout@v3
Expand Down Expand Up @@ -91,15 +91,15 @@ jobs:
java: [11, 17, 21]

name: Test MLCommons Plugin on linux docker
if: github.repository == 'opensearch-project/ml-commons'
# if: github.repository == 'opensearch-project/ml-commons'
environment: ml-commons-cicd-env
runs-on: ubuntu-latest

steps:
- uses: aws-actions/configure-aws-credentials@v2
with:
role-to-assume: ${{ secrets.ML_ROLE }}
aws-region: us-west-2
# - uses: aws-actions/configure-aws-credentials@v2
# with:
# role-to-assume: ${{ secrets.ML_ROLE }}
# aws-region: us-west-2

- name: Checkout MLCommons
uses: actions/checkout@v3
Expand Down Expand Up @@ -141,24 +141,29 @@ jobs:
else
echo "imagePresent=false" >> $GITHUB_ENV
fi
- name: Generate Password For Admin
id: genpass
run: |
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
- name: Run Docker Image
if: env.imagePresent == 'true'
run: |
cd ..
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" opensearch-ml:test
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" -e OPENSEARCH_INITIAL_ADMIN_PASSWORD=${{ steps.genpass.outputs.password }} opensearch-ml:test
sleep 90
- name: Run MLCommons Test
if: env.imagePresent == 'true'
run: |
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:admin --insecure |grep opensearch-security|wc -l`
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:${{ steps.genpass.outputs.password }} --insecure |grep opensearch-security|wc -l`
export OPENAI_KEY=$(aws secretsmanager get-secret-value --secret-id github_openai_key --query SecretString --output text)
export COHERE_KEY=$(aws secretsmanager get-secret-value --secret-id github_cohere_key --query SecretString --output text)
echo "::add-mask::$OPENAI_KEY"
echo "::add-mask::$COHERE_KEY"
if [ $security -gt 0 ]
then
echo "Security plugin is available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=${{ steps.genpass.outputs.password }}
else
echo "Security plugin is NOT available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ public void createConversation(String name, ActionListener<String> listener) {
public void getConversations(int from, int maxResults, ActionListener<List<ConversationMeta>> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(List.of());
return;
}
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = getUserStrFromThreadContext();
Expand Down Expand Up @@ -250,6 +251,7 @@ public void getConversations(int maxResults, ActionListener<List<ConversationMet
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(true);
return;
}
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = getUserStrFromThreadContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public String encrypt(String plainText) {
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();
byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
// https://github.com/aws/aws-encryption-sdk-java/issues/1879
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> encryptResult = crypto
.encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8));
Expand All @@ -81,7 +82,7 @@ public String decrypt(String encryptedText) {
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();

byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> decryptedResult = crypto
.decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1222,21 +1222,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user2Client,
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);
} catch (IOException e) {
assertNull(e);
}
// User2 fails to get model group
try {
getModelGroup(user3Client, modelGroupId, null);
getModelGroup(user3Client, modelGroupId1, null);
} catch (Exception e) {
assertEquals(ResponseException.class, e.getClass());
assertTrue(
Expand All @@ -1256,21 +1256,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user1Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User3 successfully gets model group
getModelGroup(
user3Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User4 successfully gets model group
getModelGroup(
user4Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);
} catch (IOException e) {
assertNull(e);
Expand All @@ -1286,14 +1286,14 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user3Client,
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);
} catch (IOException e) {
assertNull(e);
Expand All @@ -1320,7 +1320,7 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
client(),
modelGroupId4,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup4"); }
);
} catch (IOException e) {
assertNull(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Response;
Expand All @@ -23,6 +24,7 @@

import com.google.common.collect.ImmutableList;

@Ignore
public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {

private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
Expand All @@ -39,7 +41,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
Expand Down Expand Up @@ -251,6 +253,7 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
assertNotNull(responseMap);
}

@Ignore("text-davinci-edit-001 been deprecated on 2024-01-04 and replaced by /v1/chat/completions")
public void testOpenAIEditsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -120,16 +122,20 @@ public void testConversations_MorePages() throws IOException {
assert (((Double) map.get("next_token")).intValue() == 1);
}

public void testGetConversations_nextPage() throws IOException {
public void testGetConversations_nextPage() throws IOException, InterruptedException {
Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse1 != null);
assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK);
HttpEntity cchttpEntity1 = ccresponse1.getEntity();
String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1);
Map ccmap1 = gson.fromJson(ccentityString1, Map.class);
assert (ccmap1.containsKey("conversation_id"));
logger.info("ccentityString1={}", ccentityString1);
String id1 = (String) ccmap1.get("conversation_id");

// wait for 0.1s to make sure update time is different between conversation 1 and 2
TimeUnit.MICROSECONDS.sleep(100);

Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse2 != null);
assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK);
Expand Down Expand Up @@ -159,7 +165,7 @@ public void testGetConversations_nextPage() throws IOException {
ArrayList<Map> conversations1 = (ArrayList<Map>) map1.get("conversations");
assert (conversations1.size() == 1);
assert (conversations1.get(0).containsKey("conversation_id"));
assert (((String) conversations1.get(0).get("conversation_id")).equals(id2));
Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2);
assert (((Double) map1.get("next_token")).intValue() == 1);

Response response = TestHelper
Expand Down
Loading