diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index da4d64c33c..40b044a966 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -26,7 +26,7 @@ jobs: java: [11, 17, 21] name: Build and Test MLCommons Plugin on linux - if: github.repository == 'opensearch-project/ml-commons' + #if: github.repository == 'opensearch-project/ml-commons' environment: ml-commons-cicd-env outputs: build-test-linux: ${{ steps.step-build-test-linux.outputs.build-test-linux }} @@ -44,10 +44,10 @@ jobs: with: java-version: ${{ matrix.java }} - - uses: aws-actions/configure-aws-credentials@v2 - with: - role-to-assume: ${{ secrets.ML_ROLE }} - aws-region: us-west-2 +# - uses: aws-actions/configure-aws-credentials@v2 +# with: +# role-to-assume: ${{ secrets.ML_ROLE }} +# aws-region: us-west-2 - name: Checkout MLCommons uses: actions/checkout@v3 @@ -91,15 +91,15 @@ jobs: java: [11, 17, 21] name: Test MLCommons Plugin on linux docker - if: github.repository == 'opensearch-project/ml-commons' + # if: github.repository == 'opensearch-project/ml-commons' environment: ml-commons-cicd-env runs-on: ubuntu-latest steps: - - uses: aws-actions/configure-aws-credentials@v2 - with: - role-to-assume: ${{ secrets.ML_ROLE }} - aws-region: us-west-2 +# - uses: aws-actions/configure-aws-credentials@v2 +# with: +# role-to-assume: ${{ secrets.ML_ROLE }} +# aws-region: us-west-2 - name: Checkout MLCommons uses: actions/checkout@v3 @@ -141,16 +141,21 @@ jobs: else echo "imagePresent=false" >> $GITHUB_ENV fi + - name: Generate Password For Admin + id: genpass + run: | + PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') + echo "password={$PASSWORD}" >> $GITHUB_OUTPUT - name: Run Docker Image if: env.imagePresent == 'true' run: | cd .. - docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" opensearch-ml:test + docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" -e OPENSEARCH_INITIAL_ADMIN_PASSWORD=${{ steps.genpass.outputs.password }} opensearch-ml:test sleep 90 - name: Run MLCommons Test if: env.imagePresent == 'true' run: | - security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:admin --insecure |grep opensearch-security|wc -l` + security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:${{ steps.genpass.outputs.password }} --insecure |grep opensearch-security|wc -l` export OPENAI_KEY=$(aws secretsmanager get-secret-value --secret-id github_openai_key --query SecretString --output text) export COHERE_KEY=$(aws secretsmanager get-secret-value --secret-id github_cohere_key --query SecretString --output text) echo "::add-mask::$OPENAI_KEY" @@ -158,7 +163,7 @@ jobs: if [ $security -gt 0 ] then echo "Security plugin is available" - ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin + ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=${{ steps.genpass.outputs.password }} else echo "Security plugin is NOT available" ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index d262b816ec..f7e2a63138 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -198,6 +198,7 @@ public void createConversation(String name, ActionListener listener) { public void getConversations(int from, int maxResults, ActionListener> listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); + return; } SearchRequest request = Requests.searchRequest(META_INDEX_NAME); String userstr = getUserStrFromThreadContext(); @@ -250,6 +251,7 @@ public void getConversations(int maxResults, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); + return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); String userstr = getUserStrFromThreadContext(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index b500709bc5..617c6871e5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -68,7 +68,8 @@ public String encrypt(String plainText) { initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); byte[] bytes = Base64.getDecoder().decode(masterKey); - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); + // https://github.com/aws/aws-encryption-sdk-java/issues/1879 + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); final CryptoResult encryptResult = crypto .encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8)); @@ -81,7 +82,7 @@ public String decrypt(String encryptedText) { final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); byte[] bytes = Base64.getDecoder().decode(masterKey); - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); final CryptoResult decryptedResult = crypto .decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java index ac8b8057a5..861eb7ec60 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java @@ -1222,21 +1222,21 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user2Client, modelGroupId1, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); } ); // Admin successfully gets model group getModelGroup( client(), modelGroupId1, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); } ); } catch (IOException e) { assertNull(e); } // User2 fails to get model group try { - getModelGroup(user3Client, modelGroupId, null); + getModelGroup(user3Client, modelGroupId1, null); } catch (Exception e) { assertEquals(ResponseException.class, e.getClass()); assertTrue( @@ -1256,21 +1256,21 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user1Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); // User3 successfully gets model group getModelGroup( user3Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); // User4 successfully gets model group getModelGroup( user4Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); } catch (IOException e) { assertNull(e); @@ -1286,14 +1286,14 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user3Client, modelGroupId3, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); } ); // Admin successfully gets model group getModelGroup( client(), modelGroupId3, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); } ); } catch (IOException e) { assertNull(e); @@ -1320,7 +1320,7 @@ public void test_get_modelGroup() throws IOException { getModelGroup( client(), modelGroupId4, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup4"); } ); } catch (IOException e) { assertNull(e); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 0c8a7c779c..637f0e9f36 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -15,6 +15,7 @@ import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; @@ -23,6 +24,7 @@ import com.google.common.collect.ImmutableList; +@Ignore public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); @@ -39,7 +41,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + " \"content_type\": \"application/json\",\n" + " \"max_tokens\": 7,\n" + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" + + " \"model\": \"gpt-3.5-turbo-instruct\"\n" + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" @@ -251,6 +253,7 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep assertNotNull(responseMap); } + @Ignore("text-davinci-edit-001 been deprecated on 2024-01-04 and replaced by /v1/chat/completions") public void testOpenAIEditsModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java index 2112793166..2b2f409908 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.apache.hc.core5.http.HttpEntity; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; @@ -120,7 +122,7 @@ public void testConversations_MorePages() throws IOException { assert (((Double) map.get("next_token")).intValue() == 1); } - public void testGetConversations_nextPage() throws IOException { + public void testGetConversations_nextPage() throws IOException, InterruptedException { Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse1 != null); assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK); @@ -128,8 +130,12 @@ public void testGetConversations_nextPage() throws IOException { String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1); Map ccmap1 = gson.fromJson(ccentityString1, Map.class); assert (ccmap1.containsKey("conversation_id")); + logger.info("ccentityString1={}", ccentityString1); String id1 = (String) ccmap1.get("conversation_id"); + // wait for 0.1s to make sure update time is different between conversation 1 and 2 + TimeUnit.MICROSECONDS.sleep(100); + Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse2 != null); assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK); @@ -159,7 +165,7 @@ public void testGetConversations_nextPage() throws IOException { ArrayList conversations1 = (ArrayList) map1.get("conversations"); assert (conversations1.size() == 1); assert (conversations1.get(0).containsKey("conversation_id")); - assert (((String) conversations1.get(0).get("conversation_id")).equals(id2)); + Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2); assert (((Double) map1.get("next_token")).intValue() == 1); Response response = TestHelper