From b3b2f2d50d68286fe4b2296d7f77fa5ce93325c3 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 19 Dec 2023 16:35:15 -0800 Subject: [PATCH] Add request level parameters for system_prompt and user_instructions. Signed-off-by: Austin Lee --- .../ml/rest/RestMLRAGSearchProcessorIT.java | 12 ++++ .../GenerativeQAResponseProcessor.java | 12 ++++ .../ext/GenerativeQAParameters.java | 27 ++++++++ .../GenerativeQAResponseProcessorTests.java | 44 +++++++++++-- .../ext/GenerativeQAParamExtBuilderTests.java | 25 +++++--- .../ext/GenerativeQAParametersTests.java | 62 ++++++++++++++++--- 6 files changed, 163 insertions(+), 19 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index f9201c15d6..4bdef055bd 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" + " \"llm_question\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d\n" @@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " \"llm_model\": \"%s\",\n" + " \"llm_question\": \"%s\",\n" + " \"memory_id\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d\n" @@ -308,6 +312,8 @@ public void testBM25WithOpenAI() throws Exception { requestParameters.match = "president"; requestParameters.llmModel = OPENAI_MODEL; requestParameters.llmQuestion = "who is lincoln"; + requestParameters.systemPrompt = "You are great at answering questions"; + requestParameters.userInstructions = "Follow my instructions as best you can"; requestParameters.contextSize = 5; requestParameters.interactionSize = 5; requestParameters.timeout = 60; @@ -527,6 +533,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.match, requestParameters.llmModel, requestParameters.llmQuestion, + requestParameters.systemPrompt, + requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, requestParameters.timeout @@ -541,6 +549,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.llmModel, requestParameters.llmQuestion, requestParameters.conversationId, + requestParameters.systemPrompt, + requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, requestParameters.timeout @@ -581,6 +591,8 @@ static class SearchRequestParameters { String match; String llmModel; String llmQuestion; + String systemPrompt; + String userInstructions; int contextSize; int interactionSize; int timeout; diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index e3e86320dd..34f5ec5c5a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -143,6 +143,18 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } List searchResults = getSearchResults(response, topN); + // See if the prompt is being overridden at the request level. + String effectiveSystemPrompt = systemPrompt; + String effectiveUserInstructions = userInstructions; + if (params.getSystemPrompt() != null) { + effectiveSystemPrompt = params.getSystemPrompt(); + } + if (params.getUserInstructions() != null) { + effectiveUserInstructions = params.getUserInstructions(); + } + log.info("system_prompt: {}", effectiveSystemPrompt); + log.info("user_instructions: {}", effectiveUserInstructions); + start = Instant.now(); try { ChatCompletionOutput output = llm diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 19b80a838f..6c0ea2d0f5 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -29,6 +29,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; import com.google.common.base.Preconditions; @@ -70,6 +71,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { // from a remote inference endpoint before timing out the request. private static final ParseField TIMEOUT = new ParseField("timeout"); + private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT); + + private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS); + public static final int SIZE_NULL_VALUE = -1; static { @@ -77,6 +82,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); + PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT); + PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS); PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); @@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private Integer timeout; + @Setter + @Getter + private String systemPrompt; + + @Setter + @Getter + private String userInstructions; + public GenerativeQAParameters( String conversationId, String llmModel, String llmQuestion, + String systemPrompt, + String userInstructions, Integer contextSize, Integer interactionSize, Integer timeout @@ -121,6 +138,8 @@ public GenerativeQAParameters( // for question rewriting. Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); this.llmQuestion = llmQuestion; + this.systemPrompt = systemPrompt; + this.userInstructions = userInstructions; this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize; this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout; @@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.conversationId = input.readOptionalString(); this.llmModel = input.readOptionalString(); this.llmQuestion = input.readString(); + this.systemPrompt = input.readOptionalString(); + this.userInstructions = input.readOptionalString(); this.contextSize = input.readInt(); this.interactionSize = input.readInt(); this.timeout = input.readInt(); @@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params .field(CONVERSATION_ID.getPreferredName(), this.conversationId) .field(LLM_MODEL.getPreferredName(), this.llmModel) .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) + .field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt) + .field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions) .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) .field(TIMEOUT.getPreferredName(), this.timeout); @@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException { Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); out.writeString(llmQuestion); + out.writeOptionalString(systemPrompt); + out.writeOptionalString(userInstructions); out.writeInt(contextSize); out.writeInt(interactionSize); out.writeInt(timeout); @@ -175,6 +200,8 @@ public boolean equals(Object o) { return Objects.equals(this.conversationId, other.getConversationId()) && Objects.equals(this.llmModel, other.getLlmModel()) && Objects.equals(this.llmQuestion, other.getLlmQuestion()) + && Objects.equals(this.systemPrompt, other.getSystemPrompt()) + && Objects.equals(this.userInstructions, other.getUserInstructions()) && (this.contextSize == other.getContextSize()) && (this.interactionSize == other.getInteractionSize()) && (this.timeout == other.getTimeout()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index 6d6e3e5c5d..2610a92ede 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -106,7 +106,16 @@ public void testProcessResponseNoSearchHits() throws Exception { SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -170,7 +179,16 @@ public void testProcessResponse() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_promt", + "user_insturctions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -245,7 +263,16 @@ public void testProcessResponseSmallerContextSize() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); int contextSize = 5; - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + contextSize, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -319,7 +346,16 @@ public void testProcessResponseMissingContextField() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 5aeb1e804f..d8748bb2f4 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -39,7 +39,16 @@ public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { public void testCtor() throws IOException { GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); - GenerativeQAParameters parameters = new GenerativeQAParameters("conversation_id", "model_id", "question", null, null, null); + GenerativeQAParameters parameters = new GenerativeQAParameters( + "conversation_id", + "model_id", + "question", + "system_promtp", + "user_instructions", + null, + null, + null + ); builder.setParams(parameters); assertEquals(parameters, builder.getParams()); @@ -79,8 +88,8 @@ public int read() throws IOException { } public void testMiscMethods() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); - GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); + GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", "s", "u", null, null, null); GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); builder1.setParams(param1); @@ -92,7 +101,7 @@ public void testMiscMethods() throws IOException { StreamOutput so = mock(StreamOutput.class); builder1.writeTo(so); - verify(so, times(2)).writeOptionalString(any()); + verify(so, times(4)).writeOptionalString(any()); verify(so, times(1)).writeString(any()); } @@ -105,7 +114,7 @@ public void testParse() throws IOException { } public void testXContentRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); @@ -120,7 +129,7 @@ public void testXContentRoundTrip() throws IOException { } public void testXContentRoundTripAllValues() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); @@ -131,7 +140,7 @@ public void testXContentRoundTripAllValues() throws IOException { } public void testStreamRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); BytesStreamOutput bso = new BytesStreamOutput(); @@ -145,7 +154,7 @@ public void testStreamRoundTrip() throws IOException { } public void testStreamRoundTripAllValues() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); BytesStreamOutput bso = new BytesStreamOutput(); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index 600b1c7a19..659a7b4e1f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -36,7 +36,16 @@ public class GenerativeQAParametersTests extends OpenSearchTestCase { public void testGenerativeQAParameters() { - GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "conversation_id", + "llm_model", + "llm_question", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); @@ -98,6 +107,8 @@ public void testWriteTo() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; + String systemPrompt = "s"; + String userInstructions = "u"; int contextSize = 1; int interactionSize = 2; int timeout = 10; @@ -105,6 +116,8 @@ public void testWriteTo() throws IOException { conversationId, llmModel, llmQuestion, + systemPrompt, + userInstructions, contextSize, interactionSize, timeout @@ -112,7 +125,7 @@ public void testWriteTo() throws IOException { StreamOutput output = new DummyStreamOutput(); parameters.writeTo(output); List actual = ((DummyStreamOutput) output).getList(); - assertEquals(3, actual.size()); + assertEquals(5, actual.size()); assertEquals(conversationId, actual.get(0)); assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); @@ -126,12 +139,32 @@ public void testMisc() { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); + String systemPrompt = "s"; + String userInstructions = "u"; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + systemPrompt, + userInstructions, + null, + null, + null + ); assertNotEquals(parameters, null); assertNotEquals(parameters, "foo"); - assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null)); - assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion, null, null, null)); - assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion, null, null, null)); + assertEquals( + parameters, + new GenerativeQAParameters(conversationId, llmModel, llmQuestion, systemPrompt, userInstructions, null, null, null) + ); + assertNotEquals( + parameters, + new GenerativeQAParameters("", llmModel, llmQuestion, systemPrompt, userInstructions, null, null, null) + ); + assertNotEquals( + parameters, + new GenerativeQAParameters(conversationId, "", llmQuestion, systemPrompt, userInstructions, null, null, null) + ); // assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "", null)); } @@ -139,7 +172,18 @@ public void testToXConent() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); + String systemPrompt = "s"; + String userInstructions = "u"; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + systemPrompt, + userInstructions, + null, + null, + null + ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); XContentGenerator generator = mock(XContentGenerator.class); @@ -152,6 +196,8 @@ public void testToXConentAllOptionalParameters() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; + String systemPrompt = "s"; + String userInstructions = "u"; int contextSize = 1; int interactionSize = 2; int timeout = 10; @@ -159,6 +205,8 @@ public void testToXConentAllOptionalParameters() throws IOException { conversationId, llmModel, llmQuestion, + systemPrompt, + userInstructions, contextSize, interactionSize, timeout