Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Add support for Bedrock Converse API (Anthropic Messages API,… #2929

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,3 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
dependsOn tasks.named("${baseName}#fullRestartClusterTask")
}

forbiddenPatterns {
exclude '**/*.pdf'
exclude '**/*.jpg'
}

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField(),
params.getLlmMessages()
params.getLlmResponseField()
),
null,
llmQuestion,
Expand All @@ -203,8 +202,7 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField(),
params.getLlmMessages()
params.getLlmResponseField()
),
conversationId,
llmQuestion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.ParseField;
Expand All @@ -32,7 +30,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

import com.google.common.base.Preconditions;

Expand Down Expand Up @@ -84,8 +81,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");

private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");

public static final int SIZE_NULL_VALUE = -1;

static {
Expand All @@ -99,7 +94,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
}

@Setter
Expand Down Expand Up @@ -138,10 +132,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private String llmResponseField;

@Setter
@Getter
private List<MessageBlock> llmMessages = new ArrayList<>();

public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -152,32 +142,6 @@ public GenerativeQAParameters(
Integer interactionSize,
Integer timeout,
String llmResponseField
) {
this(
conversationId,
llmModel,
llmQuestion,
systemPrompt,
userInstructions,
contextSize,
interactionSize,
timeout,
llmResponseField,
null
);
}

public GenerativeQAParameters(
String conversationId,
String llmModel,
String llmQuestion,
String systemPrompt,
String userInstructions,
Integer contextSize,
Integer interactionSize,
Integer timeout,
String llmResponseField,
List<MessageBlock> llmMessages
) {
this.conversationId = conversationId;
this.llmModel = llmModel;
Expand All @@ -192,9 +156,6 @@ public GenerativeQAParameters(
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
this.llmResponseField = llmResponseField;
if (llmMessages != null) {
this.llmMessages.addAll(llmMessages);
}
}

public GenerativeQAParameters(StreamInput input) throws IOException {
Expand All @@ -207,7 +168,6 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.interactionSize = input.readInt();
this.timeout = input.readInt();
this.llmResponseField = input.readOptionalString();
this.llmMessages.addAll(input.readList(MessageBlock::new));
}

@Override
Expand All @@ -221,8 +181,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout)
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
}

@Override
Expand All @@ -238,7 +197,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(interactionSize);
out.writeInt(timeout);
out.writeOptionalString(llmResponseField);
out.writeList(llmMessages);
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
Expand All @@ -265,8 +223,4 @@ public boolean equals(Object o) {
&& (this.timeout == other.getTimeout())
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
}

public void setMessageBlock(List<MessageBlock> blockList) {
this.llmMessages = blockList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,4 @@ public class ChatCompletionInput {
private String userInstructions;
private Llm.ModelProvider modelProvider;
private String llmResponseField;
private List<MessageBlock> llmMessages;
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
* @return
*/
@Override

public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
Expand Down Expand Up @@ -112,15 +113,14 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
// log.info("Messages to LLM: {}", messages);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
|| chatCompletionInput.getLlmResponseField() != null) {
Expand All @@ -136,19 +136,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
chatCompletionInput.getContexts()
)
);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
// Bedrock Converse API does not include the system prompt as part of the Messages block.
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
null,
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: "
Expand All @@ -157,6 +144,7 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
);
}

// log.info("LLM input parameters: {}", inputParameters.toString());
return inputParameters;
}

Expand Down Expand Up @@ -196,20 +184,6 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
} else if (provider == ModelProvider.COHERE) {
answerField = "text";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
} else if (provider == ModelProvider.BEDROCK_CONVERSE) {
Map output = (Map) dataAsMap.get("output");
Map message = (Map) output.get("message");
if (message != null) {
List content = (List) message.get("content");
String answer = (String) ((Map) content.get(0)).get("text");
answers.add(answer);
} else {
Map error = (Map) output.get("error");
if (error == null) {
throw new RuntimeException("Unexpected output: " + output);
}
errors.add((String) error.get("message"));
}
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ public interface Llm {
enum ModelProvider {
OPENAI,
BEDROCK,
COHERE,
BEDROCK_CONVERSE
COHERE
}

void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public class LlmIOUtil {

public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
public static final String COHERE_PROVIDER_PREFIX = "cohere/";
public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
Expand All @@ -50,8 +49,7 @@ public static ChatCompletionInput createChatCompletionInput(
chatHistory,
contexts,
timeoutInSeconds,
llmResponseField,
null
llmResponseField
);
}

Expand All @@ -63,8 +61,7 @@ public static ChatCompletionInput createChatCompletionInput(
List<Interaction> chatHistory,
List<String> contexts,
int timeoutInSeconds,
String llmResponseField,
List<MessageBlock> llmMessages
String llmResponseField
) {
Llm.ModelProvider provider = null;
if (llmResponseField == null) {
Expand All @@ -74,8 +71,6 @@ public static ChatCompletionInput createChatCompletionInput(
provider = Llm.ModelProvider.BEDROCK;
} else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.COHERE;
} else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.BEDROCK_CONVERSE;
}
}
}
Expand All @@ -88,8 +83,7 @@ public static ChatCompletionInput createChatCompletionInput(
systemPrompt,
userInstructions,
provider,
llmResponseField,
llmMessages
llmResponseField
);
}
}
Loading
Loading