From 85e0c4145bfd4512dc085702b4d92eef46d3c464 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Mon, 27 Jan 2025 14:57:46 -0800 Subject: [PATCH] Cherry-pick BWC fix for system prompt and user instructions (#3437) * BWC (rag processor): add version control for newly added request params (#3125) (#3364) * gradle spotless Signed-off-by: Pavan Yekbote --------- Signed-off-by: Pavan Yekbote Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> (cherry picked from commit f4d4481f3da33b8543023b777a51f8e8c70802e8) --- .../ext/GenerativeQAParameters.java | 27 ++++++++++++++----- .../ext/GenerativeQAParamExtBuilderTests.java | 14 ++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 0cf51a8549..9258e790f8 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -90,6 +90,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { static final Version MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES = CommonValue.VERSION_2_18_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS = CommonValue.VERSION_2_13_0; + @Setter @Getter private String conversationId; @@ -200,16 +202,23 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.llmQuestion = input.readString(); } - this.systemPrompt = input.readOptionalString(); - this.userInstructions = input.readOptionalString(); + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + this.systemPrompt = input.readOptionalString(); + this.userInstructions = input.readOptionalString(); + } + this.contextSize = input.readInt(); this.interactionSize = input.readInt(); this.timeout = input.readInt(); - this.llmResponseField = input.readOptionalString(); + + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + this.llmResponseField = input.readOptionalString(); + } if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) { this.llmMessages.addAll(input.readList(MessageBlock::new)); } + } @Override @@ -272,12 +281,18 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(llmQuestion); } - out.writeOptionalString(systemPrompt); - out.writeOptionalString(userInstructions); + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + out.writeOptionalString(systemPrompt); + out.writeOptionalString(userInstructions); + } + out.writeInt(contextSize); out.writeInt(interactionSize); out.writeInt(timeout); - out.writeOptionalString(llmResponseField); + + if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS)) { + out.writeOptionalString(llmResponseField); + } if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) { out.writeList(llmMessages); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 00ae5f27b7..4d63a0157f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -121,6 +121,7 @@ public void testMiscMethods() throws IOException { assertNotEquals(builder1, builder2); assertNotEquals(builder1.hashCode(), builder2.hashCode()); + // BWC test for bedrock converse params StreamOutput so1 = mock(StreamOutput.class); when(so1.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES); builder1.writeTo(so1); @@ -130,6 +131,19 @@ public void testMiscMethods() throws IOException { when(so2.getVersion()).thenReturn(Version.V_2_17_0); builder1.writeTo(so2); verify(so2, times(5)).writeOptionalString(any()); + + // BWC test for system prompt and instructions + StreamOutput so3 = mock(StreamOutput.class); + when(so3.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_PROMPT_AND_INSTRUCTIONS); + builder1.writeTo(so3); + verify(so3, times(5)).writeOptionalString(any()); + verify(so3, times(1)).writeString(any()); + + StreamOutput so4 = mock(StreamOutput.class); + when(so4.getVersion()).thenReturn(Version.V_2_12_0); + builder1.writeTo(so4); + verify(so4, times(2)).writeOptionalString(any()); + verify(so4, times(1)).writeString(any()); } public void testParse() throws IOException {