Skip to content

Commit

Permalink
Add request-level parameters for system_prompt and user_instructions.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Mar 20, 2024
1 parent 7c7330d commit a066434
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d\n"
Expand All @@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"memory_id\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d\n"
Expand Down Expand Up @@ -283,6 +287,8 @@ public void testBM25WithOpenAI() throws Exception {
requestParameters.match = "president";
requestParameters.llmModel = OPENAI_MODEL;
requestParameters.llmQuestion = "who is lincoln";
requestParameters.systemPrompt = "You are great at answering questions";
requestParameters.userInstructions = "Follow my instructions as best you can";
requestParameters.contextSize = 5;
requestParameters.interactionSize = 5;
requestParameters.timeout = 60;
Expand Down Expand Up @@ -502,6 +508,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
requestParameters.timeout
Expand All @@ -516,6 +524,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.conversationId,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
requestParameters.timeout
Expand Down Expand Up @@ -556,6 +566,8 @@ static class SearchRequestParameters {
String match;
String llmModel;
String llmQuestion;
String systemPrompt;
String userInstructions;
int contextSize;
int interactionSize;
int timeout;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}
List<String> searchResults = getSearchResults(response, topN);

// See if the prompt is being overridden at the request level.
String effectiveSystemPrompt = systemPrompt;
String effectiveUserInstructions = userInstructions;
if (params.getSystemPrompt() != null) {
effectiveSystemPrompt = params.getSystemPrompt();
}
if (params.getUserInstructions() != null) {
effectiveUserInstructions = params.getUserInstructions();
}

start = Instant.now();
try {
ChatCompletionOutput output = llm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;

import com.google.common.base.Preconditions;

Expand Down Expand Up @@ -70,13 +71,19 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
// from a remote inference endpoint before timing out the request.
private static final ParseField TIMEOUT = new ParseField("timeout");

private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);

private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);

public static final int SIZE_NULL_VALUE = -1;

static {
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
Expand Down Expand Up @@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private Integer timeout;

@Setter
@Getter
private String systemPrompt;

@Setter
@Getter
private String userInstructions;

public GenerativeQAParameters(
String conversationId,
String llmModel,
String llmQuestion,
String systemPrompt,
String userInstructions,
Integer contextSize,
Integer interactionSize,
Integer timeout
Expand All @@ -121,6 +138,8 @@ public GenerativeQAParameters(
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
Expand All @@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.conversationId = input.readOptionalString();
this.llmModel = input.readOptionalString();
this.llmQuestion = input.readString();
this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
this.contextSize = input.readInt();
this.interactionSize = input.readInt();
this.timeout = input.readInt();
Expand All @@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
.field(LLM_MODEL.getPreferredName(), this.llmModel)
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout);
Expand All @@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException {

Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
out.writeString(llmQuestion);
out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
out.writeInt(contextSize);
out.writeInt(interactionSize);
out.writeInt(timeout);
Expand All @@ -175,6 +200,8 @@ public boolean equals(Object o) {
return Objects.equals(this.conversationId, other.getConversationId())
&& Objects.equals(this.llmModel, other.getLlmModel())
&& Objects.equals(this.llmQuestion, other.getLlmQuestion())
&& Objects.equals(this.systemPrompt, other.getSystemPrompt())
&& Objects.equals(this.userInstructions, other.getUserInstructions())
&& (this.contextSize == other.getContextSize())
&& (this.interactionSize == other.getInteractionSize())
&& (this.timeout == other.getTimeout());
Expand Down
Loading

0 comments on commit a066434

Please sign in to comment.