Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Pavan Yekbote <[email protected]>
  • Loading branch information
pyek-bot committed Jan 28, 2025
2 parents 83332ff + 17251cd commit a083b32
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ private MLCommonsSettings() {}
"^https://api\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
"^https://api\\.deepseek\\.com/.*$",
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase {
"^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
REKOGNITION_TRUST_ENDPOINT_REGEX
REKOGNITION_TRUST_ENDPOINT_REGEX,
"^https://api\\.deepseek\\.com/.*$"
);

@Before
Expand Down Expand Up @@ -546,6 +547,55 @@ public void test_execute_URL_notMatchingExpression_exception() {
);
}

public void test_connector_creation_success_deepseek() {
TransportCreateConnectorAction action = new TransportCreateConnectorAction(
transportService,
actionFilters,
mlIndicesHandler,
client,
sdkClient,
mlEngine,
connectorAccessControlHelper,
settings,
clusterService,
mlModelManager,
mlFeatureEnabledSetting
);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class));
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(client).index(any(IndexRequest.class), isA(ActionListener.class));
List<ConnectorAction> actions = new ArrayList<>();
actions
.add(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.deepseek.com/v1/chat/completions")
.build()
);
Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput
.builder()
.name(randomAlphaOfLength(5))
.description(randomAlphaOfLength(10))
.version("1")
.protocol(ConnectorProtocols.HTTP)
.credential(credential)
.actions(actions)
.build();
MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);
action.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
}

public void test_connector_creation_success_rekognition() {
TransportCreateConnectorAction action = new TransportCreateConnectorAction(
transportService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {

static final Version MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES = CommonValue.VERSION_2_18_0;

public static final Version MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS = CommonValue.VERSION_2_13_0;

@Setter
@Getter
private String conversationId;
Expand Down Expand Up @@ -200,16 +202,23 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.llmQuestion = input.readString();
}

this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) {
this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
}

this.contextSize = input.readInt();
this.interactionSize = input.readInt();
this.timeout = input.readInt();
this.llmResponseField = input.readOptionalString();

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) {
this.llmResponseField = input.readOptionalString();
}

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
this.llmMessages.addAll(input.readList(MessageBlock::new));
}

}

@Override
Expand Down Expand Up @@ -272,12 +281,18 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(llmQuestion);
}

out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) {
out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
}

out.writeInt(contextSize);
out.writeInt(interactionSize);
out.writeInt(timeout);
out.writeOptionalString(llmResponseField);

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) {
out.writeOptionalString(llmResponseField);
}

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
out.writeList(llmMessages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ public void testMiscMethods() throws IOException {
assertNotEquals(builder1, builder2);
assertNotEquals(builder1.hashCode(), builder2.hashCode());

// BWC test for bedrock converse params
StreamOutput so1 = mock(StreamOutput.class);
when(so1.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
builder1.writeTo(so1);
Expand All @@ -130,6 +131,19 @@ public void testMiscMethods() throws IOException {
when(so2.getVersion()).thenReturn(Version.V_2_17_0);
builder1.writeTo(so2);
verify(so2, times(5)).writeOptionalString(any());

// BWC test for system prompt and instructions
StreamOutput so3 = mock(StreamOutput.class);
when(so3.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS);
builder1.writeTo(so3);
verify(so3, times(5)).writeOptionalString(any());
verify(so3, times(1)).writeString(any());

StreamOutput so4 = mock(StreamOutput.class);
when(so4.getVersion()).thenReturn(Version.V_2_12_0);
builder1.writeTo(so4);
verify(so4, times(2)).writeOptionalString(any());
verify(so4, times(1)).writeString(any());
}

public void testParse() throws IOException {
Expand Down

0 comments on commit a083b32

Please sign in to comment.