diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 75cc64a0d3..1942e31f39 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" + " \"llm_question\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d\n" @@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { + " \"llm_model\": \"%s\",\n" + " \"llm_question\": \"%s\",\n" + " \"memory_id\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d\n" @@ -283,6 +287,8 @@ public void testBM25WithOpenAI() throws Exception { requestParameters.match = "president"; requestParameters.llmModel = OPENAI_MODEL; requestParameters.llmQuestion = "who is lincoln"; + requestParameters.systemPrompt = "You are great at answering questions"; + requestParameters.userInstructions = "Follow my instructions as best you can"; requestParameters.contextSize = 5; requestParameters.interactionSize = 5; requestParameters.timeout = 60; @@ -502,6 +508,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.match, requestParameters.llmModel, requestParameters.llmQuestion, + requestParameters.systemPrompt, + requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, requestParameters.timeout @@ -516,6 +524,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.llmModel, requestParameters.llmQuestion, requestParameters.conversationId, + requestParameters.systemPrompt, + requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, requestParameters.timeout @@ -556,6 +566,8 @@ static class SearchRequestParameters { String match; String llmModel; String llmQuestion; + String systemPrompt; + String userInstructions; int contextSize; int interactionSize; int timeout; diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index e3e86320dd..c2f0bc5351 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -143,6 +143,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } List searchResults = getSearchResults(response, topN); + // See if the prompt is being overridden at the request level. + String effectiveSystemPrompt = systemPrompt; + String effectiveUserInstructions = userInstructions; + if (params.getSystemPrompt() != null) { + effectiveSystemPrompt = params.getSystemPrompt(); + } + if (params.getUserInstructions() != null) { + effectiveUserInstructions = params.getUserInstructions(); + } + start = Instant.now(); try { ChatCompletionOutput output = llm diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 19b80a838f..6c0ea2d0f5 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -29,6 +29,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; import com.google.common.base.Preconditions; @@ -70,6 +71,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { // from a remote inference endpoint before timing out the request. private static final ParseField TIMEOUT = new ParseField("timeout"); + private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT); + + private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS); + public static final int SIZE_NULL_VALUE = -1; static { @@ -77,6 +82,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); + PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT); + PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS); PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); @@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private Integer timeout; + @Setter + @Getter + private String systemPrompt; + + @Setter + @Getter + private String userInstructions; + public GenerativeQAParameters( String conversationId, String llmModel, String llmQuestion, + String systemPrompt, + String userInstructions, Integer contextSize, Integer interactionSize, Integer timeout @@ -121,6 +138,8 @@ public GenerativeQAParameters( // for question rewriting. Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); this.llmQuestion = llmQuestion; + this.systemPrompt = systemPrompt; + this.userInstructions = userInstructions; this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize; this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout; @@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.conversationId = input.readOptionalString(); this.llmModel = input.readOptionalString(); this.llmQuestion = input.readString(); + this.systemPrompt = input.readOptionalString(); + this.userInstructions = input.readOptionalString(); this.contextSize = input.readInt(); this.interactionSize = input.readInt(); this.timeout = input.readInt(); @@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params .field(CONVERSATION_ID.getPreferredName(), this.conversationId) .field(LLM_MODEL.getPreferredName(), this.llmModel) .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) + .field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt) + .field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions) .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) .field(TIMEOUT.getPreferredName(), this.timeout); @@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException { Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); out.writeString(llmQuestion); + out.writeOptionalString(systemPrompt); + out.writeOptionalString(userInstructions); out.writeInt(contextSize); out.writeInt(interactionSize); out.writeInt(timeout); @@ -175,6 +200,8 @@ public boolean equals(Object o) { return Objects.equals(this.conversationId, other.getConversationId()) && Objects.equals(this.llmModel, other.getLlmModel()) && Objects.equals(this.llmQuestion, other.getLlmQuestion()) + && Objects.equals(this.systemPrompt, other.getSystemPrompt()) + && Objects.equals(this.userInstructions, other.getUserInstructions()) && (this.contextSize == other.getContextSize()) && (this.interactionSize == other.getInteractionSize()) && (this.timeout == other.getTimeout()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index 6d6e3e5c5d..1b566e45cf 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -34,6 +34,7 @@ import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; +import org.opensearch.OpenSearchException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -106,7 +107,16 @@ public void testProcessResponseNoSearchHits() throws Exception { SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -170,7 +180,16 @@ public void testProcessResponse() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_promt", + "user_insturctions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -212,6 +231,90 @@ public void testProcessResponse() throws Exception { assertTrue(res instanceof GenerativeSearchResponse); } + public void testProcessResponseWithErrorFromLlm() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn( + List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ) + ); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_promt", + "user_insturctions", + null, + null, + null + ); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.isErrorOccurred()).thenReturn(true); + when(output.getErrors()).thenReturn(List.of("something bad has occurred.")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + SearchResponse res = processor.processResponse(request, response); + verify(llm).doChatCompletion(captor.capture()); + ChatCompletionInput input = captor.getValue(); + assertTrue(input instanceof ChatCompletionInput); + List passages = ((ChatCompletionInput) input).getContexts(); + assertEquals("passage0", passages.get(0)); + assertEquals("passage1", passages.get(1)); + assertEquals(numHits, passages.size()); + assertTrue(res instanceof GenerativeSearchResponse); + } + public void testProcessResponseSmallerContextSize() throws Exception { Client client = mock(Client.class); Map config = new HashMap<>(); @@ -245,7 +348,16 @@ public void testProcessResponseSmallerContextSize() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); int contextSize = 5; - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + contextSize, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -319,7 +431,16 @@ public void testProcessResponseMissingContextField() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "You are kind.", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -431,7 +552,16 @@ public void testProcessResponseNullValueInteractions() throws Exception { SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); int contextSize = 5; - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "llm_model", + "Question", + "You are kind.", + null, + contextSize, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -461,4 +591,128 @@ public void testProcessResponseNullValueInteractions() throws Exception { SearchResponse res = processor.processResponse(request, response); } + + public void testProcessResponseIllegalArgument() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("llm_model cannot be null."); + + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + int contextSize = 5; + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + null, + "Question", + "You are kind.", + null, + contextSize, + null, + null + ); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + // when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions")); + processor.setLlm(llm); + + SearchResponse res = processor.processResponse(request, response); + } + + public void testProcessResponseOpenSearchException() throws Exception { + exceptionRule.expect(OpenSearchException.class); + exceptionRule.expectMessage("GenerativeQAResponseProcessor failed in precessing response"); + + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + int contextSize = 5; + GenerativeQAParameters params = new GenerativeQAParameters( + "12345", + "model", + "Question", + "You are kind.", + null, + contextSize, + null, + null + ); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + when(llm.doChatCompletion(any())).thenThrow(new RuntimeException()); + processor.setLlm(llm); + + SearchResponse res = processor.processResponse(request, response); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 5aeb1e804f..d8748bb2f4 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -39,7 +39,16 @@ public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { public void testCtor() throws IOException { GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); - GenerativeQAParameters parameters = new GenerativeQAParameters("conversation_id", "model_id", "question", null, null, null); + GenerativeQAParameters parameters = new GenerativeQAParameters( + "conversation_id", + "model_id", + "question", + "system_promtp", + "user_instructions", + null, + null, + null + ); builder.setParams(parameters); assertEquals(parameters, builder.getParams()); @@ -79,8 +88,8 @@ public int read() throws IOException { } public void testMiscMethods() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); - GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); + GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", "s", "u", null, null, null); GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); builder1.setParams(param1); @@ -92,7 +101,7 @@ public void testMiscMethods() throws IOException { StreamOutput so = mock(StreamOutput.class); builder1.writeTo(so); - verify(so, times(2)).writeOptionalString(any()); + verify(so, times(4)).writeOptionalString(any()); verify(so, times(1)).writeString(any()); } @@ -105,7 +114,7 @@ public void testParse() throws IOException { } public void testXContentRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); @@ -120,7 +129,7 @@ public void testXContentRoundTrip() throws IOException { } public void testXContentRoundTripAllValues() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); @@ -131,7 +140,7 @@ public void testXContentRoundTripAllValues() throws IOException { } public void testStreamRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); BytesStreamOutput bso = new BytesStreamOutput(); @@ -145,7 +154,7 @@ public void testStreamRoundTrip() throws IOException { } public void testStreamRoundTripAllValues() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); BytesStreamOutput bso = new BytesStreamOutput(); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index 600b1c7a19..659a7b4e1f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -36,7 +36,16 @@ public class GenerativeQAParametersTests extends OpenSearchTestCase { public void testGenerativeQAParameters() { - GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question", null, null, null); + GenerativeQAParameters params = new GenerativeQAParameters( + "conversation_id", + "llm_model", + "llm_question", + "system_prompt", + "user_instructions", + null, + null, + null + ); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); @@ -98,6 +107,8 @@ public void testWriteTo() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; + String systemPrompt = "s"; + String userInstructions = "u"; int contextSize = 1; int interactionSize = 2; int timeout = 10; @@ -105,6 +116,8 @@ public void testWriteTo() throws IOException { conversationId, llmModel, llmQuestion, + systemPrompt, + userInstructions, contextSize, interactionSize, timeout @@ -112,7 +125,7 @@ public void testWriteTo() throws IOException { StreamOutput output = new DummyStreamOutput(); parameters.writeTo(output); List actual = ((DummyStreamOutput) output).getList(); - assertEquals(3, actual.size()); + assertEquals(5, actual.size()); assertEquals(conversationId, actual.get(0)); assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); @@ -126,12 +139,32 @@ public void testMisc() { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); + String systemPrompt = "s"; + String userInstructions = "u"; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + systemPrompt, + userInstructions, + null, + null, + null + ); assertNotEquals(parameters, null); assertNotEquals(parameters, "foo"); - assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null)); - assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion, null, null, null)); - assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion, null, null, null)); + assertEquals( + parameters, + new GenerativeQAParameters(conversationId, llmModel, llmQuestion, systemPrompt, userInstructions, null, null, null) + ); + assertNotEquals( + parameters, + new GenerativeQAParameters("", llmModel, llmQuestion, systemPrompt, userInstructions, null, null, null) + ); + assertNotEquals( + parameters, + new GenerativeQAParameters(conversationId, "", llmQuestion, systemPrompt, userInstructions, null, null, null) + ); // assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "", null)); } @@ -139,7 +172,18 @@ public void testToXConent() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); + String systemPrompt = "s"; + String userInstructions = "u"; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + systemPrompt, + userInstructions, + null, + null, + null + ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); XContentGenerator generator = mock(XContentGenerator.class); @@ -152,6 +196,8 @@ public void testToXConentAllOptionalParameters() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; + String systemPrompt = "s"; + String userInstructions = "u"; int contextSize = 1; int interactionSize = 2; int timeout = 10; @@ -159,6 +205,8 @@ public void testToXConentAllOptionalParameters() throws IOException { conversationId, llmModel, llmQuestion, + systemPrompt, + userInstructions, contextSize, interactionSize, timeout