Skip to content

Commit

Permalink
Add support for Bedrock Converse API (Anthropic Messages API, Claude …
Browse files Browse the repository at this point in the history
…3.5 Sonnet) (opensearch-project#2851)

* Add support for Anthropic Message API (Issue 2826)

Signed-off-by: Austin Lee <[email protected]>

* Fix a bug.

Signed-off-by: Austin Lee <[email protected]>

* Add unit tests, improve coverage, clean up code.

Signed-off-by: Austin Lee <[email protected]>

* Allow pdf and jpg files for IT tests for multimodel conversation API testing.

Signed-off-by: Austin Lee <[email protected]>

* Fix spotless check issues.

Signed-off-by: Austin Lee <[email protected]>

* Update IT to work with session tokens.

Signed-off-by: Austin Lee <[email protected]>

* Fix MLRAGSearchProcessorIT not to extend RestMLRemoteInferenceIT.

Signed-off-by: Austin Lee <[email protected]>

* Use suite specific model group name.

Signed-off-by: Austin Lee <[email protected]>

* Disable tests that require futher investigation.

Signed-off-by: Austin Lee <[email protected]>

* Skip two additional tests with time-outs.

Signed-off-by: Austin Lee <[email protected]>

* Restore a change from RestMLRemoteInferenceIT.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee authored Sep 6, 2024
1 parent eca963f commit 17e81ae
Show file tree
Hide file tree
Showing 18 changed files with 1,869 additions and 52 deletions.
5 changes: 5 additions & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,8 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
dependsOn tasks.named("${baseName}#fullRestartClusterTask")
}

forbiddenPatterns {
exclude '**/*.pdf'
exclude '**/*.jpg'
}

Large diffs are not rendered by default.

Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField()
params.getLlmResponseField(),
params.getLlmMessages()
),
null,
llmQuestion,
Expand All @@ -202,7 +203,8 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField()
params.getLlmResponseField(),
params.getLlmMessages()
),
conversationId,
llmQuestion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.ParseField;
Expand All @@ -30,6 +32,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

import com.google.common.base.Preconditions;

Expand Down Expand Up @@ -81,6 +84,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");

private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");

public static final int SIZE_NULL_VALUE = -1;

static {
Expand All @@ -94,6 +99,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
}

@Setter
Expand Down Expand Up @@ -132,6 +138,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private String llmResponseField;

@Setter
@Getter
private List<MessageBlock> llmMessages = new ArrayList<>();

public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -142,6 +152,32 @@ public GenerativeQAParameters(
Integer interactionSize,
Integer timeout,
String llmResponseField
) {
this(
conversationId,
llmModel,
llmQuestion,
systemPrompt,
userInstructions,
contextSize,
interactionSize,
timeout,
llmResponseField,
null
);
}

public GenerativeQAParameters(
String conversationId,
String llmModel,
String llmQuestion,
String systemPrompt,
String userInstructions,
Integer contextSize,
Integer interactionSize,
Integer timeout,
String llmResponseField,
List<MessageBlock> llmMessages
) {
this.conversationId = conversationId;
this.llmModel = llmModel;
Expand All @@ -156,6 +192,9 @@ public GenerativeQAParameters(
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
this.llmResponseField = llmResponseField;
if (llmMessages != null) {
this.llmMessages.addAll(llmMessages);
}
}

public GenerativeQAParameters(StreamInput input) throws IOException {
Expand All @@ -168,6 +207,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.interactionSize = input.readInt();
this.timeout = input.readInt();
this.llmResponseField = input.readOptionalString();
this.llmMessages.addAll(input.readList(MessageBlock::new));
}

@Override
Expand All @@ -181,7 +221,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout)
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
}

@Override
Expand All @@ -197,6 +238,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(interactionSize);
out.writeInt(timeout);
out.writeOptionalString(llmResponseField);
out.writeList(llmMessages);
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
Expand All @@ -223,4 +265,8 @@ public boolean equals(Object o) {
&& (this.timeout == other.getTimeout())
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
}

public void setMessageBlock(List<MessageBlock> blockList) {
this.llmMessages = blockList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ public class ChatCompletionInput {
private String userInstructions;
private Llm.ModelProvider modelProvider;
private String llmResponseField;
private List<MessageBlock> llmMessages;
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
* @return
*/
@Override

public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
Expand Down Expand Up @@ -113,14 +112,15 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
// log.info("Messages to LLM: {}", messages);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
|| chatCompletionInput.getLlmResponseField() != null) {
Expand All @@ -136,6 +136,19 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
chatCompletionInput.getContexts()
)
);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
// Bedrock Converse API does not include the system prompt as part of the Messages block.
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
null,
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: "
Expand All @@ -144,7 +157,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
);
}

// log.info("LLM input parameters: {}", inputParameters.toString());
return inputParameters;
}

Expand Down Expand Up @@ -184,6 +196,20 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
} else if (provider == ModelProvider.COHERE) {
answerField = "text";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
} else if (provider == ModelProvider.BEDROCK_CONVERSE) {
Map output = (Map) dataAsMap.get("output");
Map message = (Map) output.get("message");
if (message != null) {
List content = (List) message.get("content");
String answer = (String) ((Map) content.get(0)).get("text");
answers.add(answer);
} else {
Map error = (Map) output.get("error");
if (error == null) {
throw new RuntimeException("Unexpected output: " + output);
}
errors.add((String) error.get("message"));
}
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public interface Llm {
enum ModelProvider {
OPENAI,
BEDROCK,
COHERE
COHERE,
BEDROCK_CONVERSE
}

void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class LlmIOUtil {

public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
public static final String COHERE_PROVIDER_PREFIX = "cohere/";
public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
Expand All @@ -49,7 +50,8 @@ public static ChatCompletionInput createChatCompletionInput(
chatHistory,
contexts,
timeoutInSeconds,
llmResponseField
llmResponseField,
null
);
}

Expand All @@ -61,7 +63,8 @@ public static ChatCompletionInput createChatCompletionInput(
List<Interaction> chatHistory,
List<String> contexts,
int timeoutInSeconds,
String llmResponseField
String llmResponseField,
List<MessageBlock> llmMessages
) {
Llm.ModelProvider provider = null;
if (llmResponseField == null) {
Expand All @@ -71,6 +74,8 @@ public static ChatCompletionInput createChatCompletionInput(
provider = Llm.ModelProvider.BEDROCK;
} else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.COHERE;
} else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.BEDROCK_CONVERSE;
}
}
}
Expand All @@ -83,7 +88,8 @@ public static ChatCompletionInput createChatCompletionInput(
systemPrompt,
userInstructions,
provider,
llmResponseField
llmResponseField,
llmMessages
);
}
}
Loading

0 comments on commit 17e81ae

Please sign in to comment.