From 212e782bfe3d2dbf1994c684fa9964ba9a0f5cc6 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 22 Jan 2025 15:31:48 -0800 Subject: [PATCH] Handle pagination_depth when from =0 (#1132) (#1136) * Handle pagination_depth when from =0 Signed-off-by: Varun Jain * Add changelog Signed-off-by: Varun Jain * Remove unecessary logs Signed-off-by: Varun Jain --------- Signed-off-by: Varun Jain (cherry picked from commit 3dbdcbaa97db3df4cf23974bdbc6d78175fb6a2e) Co-authored-by: Varun Jain --- CHANGELOG.md | 1 + .../query/HybridQueryBuilder.java | 5 +-- .../search/query/HybridCollectorManager.java | 17 +++++--- .../query/HybridQueryBuilderTests.java | 1 + .../neuralsearch/query/HybridQueryIT.java | 35 +--------------- .../query/HybridCollectorManagerTests.java | 41 ++++++++++++++++++- 6 files changed, 55 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b5c2ca0d..d265fa9f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043)) - Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)). - Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062)) +- Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132)) ### Infrastructure - Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852)) - Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835)) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index bea94e603..c8737b94c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -58,7 +58,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); String queryName = null; @@ -324,7 +323,7 @@ private Collection toQueries(Collection queryBuilders, Quer return queries; } - private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) { + private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) { if (Objects.isNull(paginationDepth)) { return; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 3c6a7271f..a7e4d9f82 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -485,14 +485,19 @@ private ReduceableSearchResult reduceSearchResults(final List 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth param is missing in the search request")); } - log.info("pagination_depth is {}", paginationDepth); - return paginationDepth; + + if (Objects.nonNull(paginationDepth)) { + return paginationDepth; + } + + // Switch to from+size retrieval size during standard hybrid query execution where from is 0. + return searchContext.size(); } /** diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index a6cf4d29e..34734c27f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -398,6 +398,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { .endObject() .endObject() .endArray() + .field("pagination_depth", 10) .endObject(); NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index c3087a1e4..91c9abb20 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -870,40 +870,6 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc assertEquals(RELATION_EQUAL_TO, total.get("relation")); } - @SneakyThrows - public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() { - try { - updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); - hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyMatchAll, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE), - null, - null, - null, - false, - null, - 2 - ); - - assertEquals(2, getHitCount(searchResponseAsMap)); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(4, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); - } - } - @SneakyThrows public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() { try { @@ -912,6 +878,7 @@ public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() { createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); ResponseException responseException = assertThrows( ResponseException.class, diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index f6948e3e3..306206e0c 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -439,7 +439,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build(); HybridQuery hybridQueryWithMatchAll = new HybridQuery( List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), @@ -633,7 +633,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), @@ -1169,4 +1169,41 @@ public void testScrollWithHybridQuery_thenFail() { illegalArgumentException.getMessage() ); } + + @SneakyThrows + public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + // From >0 + when(searchContext.from()).thenReturn(5); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + + HybridQuery hybridQuery = new HybridQuery( + List.of(termSubQuery.toQuery(mockQueryShardContext)), + HybridQueryContext.builder().build() // pagination_depth is set to null + ); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "pagination_depth param is missing in the search request"), + illegalArgumentException.getMessage() + ); + } }