diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 310fbfb181..082fe04535 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -167,6 +167,7 @@ private MLCommonsSettings() {} "^https://api\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$", + "^https://api\\.deepseek\\.com/.*$", "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", "^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", "^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index f49c5b7fc6..3927bb68ab 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -123,7 +123,8 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$", - REKOGNITION_TRUST_ENDPOINT_REGEX + REKOGNITION_TRUST_ENDPOINT_REGEX, + "^https://api\\.deepseek\\.com/.*$" ); @Before @@ -546,6 +547,55 @@ public void test_execute_URL_notMatchingExpression_exception() { ); } + public void test_connector_creation_success_deepseek() { + TransportCreateConnectorAction action = new TransportCreateConnectorAction( + transportService, + actionFilters, + mlIndicesHandler, + client, + sdkClient, + mlEngine, + connectorAccessControlHelper, + settings, + clusterService, + mlModelManager, + mlFeatureEnabledSetting + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); + List actions = new ArrayList<>(); + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.deepseek.com/v1/chat/completions") + .build() + ); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name(randomAlphaOfLength(5)) + .description(randomAlphaOfLength(10)) + .version("1") + .protocol(ConnectorProtocols.HTTP) + .credential(credential) + .actions(actions) + .build(); + MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); + } + public void test_connector_creation_success_rekognition() { TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 0cf51a8549..9258e790f8 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -90,6 +90,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { static final Version MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES = CommonValue.VERSION_2_18_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS = CommonValue.VERSION_2_13_0; + @Setter @Getter private String conversationId; @@ -200,16 +202,23 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.llmQuestion = input.readString(); } - this.systemPrompt = input.readOptionalString(); - this.userInstructions = input.readOptionalString(); + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + this.systemPrompt = input.readOptionalString(); + this.userInstructions = input.readOptionalString(); + } + this.contextSize = input.readInt(); this.interactionSize = input.readInt(); this.timeout = input.readInt(); - this.llmResponseField = input.readOptionalString(); + + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + this.llmResponseField = input.readOptionalString(); + } if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) { this.llmMessages.addAll(input.readList(MessageBlock::new)); } + } @Override @@ -272,12 +281,18 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(llmQuestion); } - out.writeOptionalString(systemPrompt); - out.writeOptionalString(userInstructions); + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + out.writeOptionalString(systemPrompt); + out.writeOptionalString(userInstructions); + } + out.writeInt(contextSize); out.writeInt(interactionSize); out.writeInt(timeout); - out.writeOptionalString(llmResponseField); + + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + out.writeOptionalString(llmResponseField); + } if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) { out.writeList(llmMessages); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 00ae5f27b7..4d63a0157f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -121,6 +121,7 @@ public void testMiscMethods() throws IOException { assertNotEquals(builder1, builder2); assertNotEquals(builder1.hashCode(), builder2.hashCode()); + // BWC test for bedrock converse params StreamOutput so1 = mock(StreamOutput.class); when(so1.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES); builder1.writeTo(so1); @@ -130,6 +131,19 @@ public void testMiscMethods() throws IOException { when(so2.getVersion()).thenReturn(Version.V_2_17_0); builder1.writeTo(so2); verify(so2, times(5)).writeOptionalString(any()); + + // BWC test for system prompt and instructions + StreamOutput so3 = mock(StreamOutput.class); + when(so3.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS); + builder1.writeTo(so3); + verify(so3, times(5)).writeOptionalString(any()); + verify(so3, times(1)).writeString(any()); + + StreamOutput so4 = mock(StreamOutput.class); + when(so4.getVersion()).thenReturn(Version.V_2_12_0); + builder1.writeTo(so4); + verify(so4, times(2)).writeOptionalString(any()); + verify(so4, times(1)).writeString(any()); } public void testParse() throws IOException {