Skip to content

Commit

Permalink
Fix: Gracefully handle error when generative_qa_parameters is not pro…
Browse files Browse the repository at this point in the history
…vided (opensearch-project#3100)

* fix: gracefully handle error when generative_qa_parameters is not provided

Signed-off-by: Pavan Yekbote <[email protected]>

* fix: spotless apply

Signed-off-by: Pavan Yekbote <[email protected]>

* docs: adding documentation link to error message

Signed-off-by: Pavan Yekbote <[email protected]>

* tests: adding UT to test null params

Signed-off-by: Pavan Yekbote <[email protected]>

---------

Signed-off-by: Pavan Yekbote <[email protected]>
  • Loading branch information
pyek-bot committed Oct 15, 2024
1 parent 45e93d7 commit 5ef4711
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants {
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled.";

public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found."
+ " Please provide ext.generative_qa_parameters to proceed."
+ " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag";
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.searchpipelines.questionanswering.generative;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -115,6 +116,9 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
if (params == null) {
throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);
}

Integer timeout = params.getTimeout();
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor.IllegalArgumentMessage;

import java.time.Instant;
Expand Down Expand Up @@ -461,4 +462,52 @@ public void testProcessResponseNullValueInteractions() throws Exception {

SearchResponse res = processor.processResponse(request, response);
}

public void testProcessResponseIllegalArgumentForNullParams() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);

Client client = mock(Client.class);
Map<String, Object> config = new HashMap<>();
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));

GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
client,
alwaysOn
).create(null, "tag", "desc", true, config, null);

ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
when(memoryClient.getInteractions(any(), anyInt()))
.thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null)));
processor.setMemoryClient(memoryClient);

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();

GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(null);
request.source(sourceBuilder);
sourceBuilder.ext(List.of(extBuilder));

int numHits = 10;
SearchHit[] hitsArray = new SearchHit[numHits];
for (int i = 0; i < numHits; i++) {
XContentBuilder sourceContent = JsonXContent
.contentBuilder()
.startObject()
.field("_id", String.valueOf(i))
.field("text", "passage" + i)
.field("title", "This is the title for document " + i)
.endObject();
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
}

SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);

SearchResponse res = processor.processResponse(request, response);
}
}

0 comments on commit 5ef4711

Please sign in to comment.