diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index c96187a65..2c79c56e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -148,11 +148,6 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } - public void addSubQuery(final Query query) { - Objects.requireNonNull(subQueries, "collection of queries must not be null"); - subQueries.add(query); - } - /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index c838a7d8e..7b9224c39 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -71,7 +71,7 @@ public boolean searchWith( query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - validateHybridQuery(query); + validateQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } @@ -112,7 +112,7 @@ && mightBeWrappedHybridQuery(query) return query; } - private void validateHybridQuery(final Query query) { + private void validateQuery(final Query query) { if (query instanceof BooleanQuery) { List booleanClauses = ((BooleanQuery) query).clauses(); for (BooleanClause booleanClause : booleanClauses) { diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 61c6c13df..e3e57a141 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List vectorFieldNam addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts + ) { + addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); + } + /** * Add a set of knn vectors and text to an index * @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, String docId, List vectorFieldNam * @param vectors List of vectors corresponding to those fields * @param textFieldNames List of text fields to be added * @param texts List of text corresponding to those fields + * @param nestedFieldNames List of nested fields to be added + * @param nestedFields List of fields and values corresponding to those fields */ @SneakyThrows protected void addKnnDoc( @@ -430,7 +444,9 @@ protected void addKnnDoc( List vectorFieldNames, List vectors, List textFieldNames, - List texts + List texts, + List nestedFieldNames, + List> nestedFields ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -441,6 +457,16 @@ protected void addKnnDoc( for (int i = 0; i < textFieldNames.size(); i++) { builder.field(textFieldNames.get(i), texts.get(i)); } + + for (int i = 0; i < nestedFieldNames.size(); i++) { + builder.field(nestedFieldNames.get(i)); + builder.startObject(); + Map nestedValues = nestedFields.get(i); + for (Map.Entry entry : nestedValues.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } builder.endObject(); request.setJsonEntity(builder.toString()); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 176628eb8..4c3ff2cb9 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -7,6 +7,7 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.query.QueryBuilders.matchQuery; import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -21,11 +22,13 @@ import lombok.SneakyThrows; +import org.apache.lucene.search.join.ScoreMode; import org.junit.After; import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -51,9 +54,14 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String NESTED_FIELD_1 = "firstname"; + private static final String NESTED_FIELD_2 = "lastname"; + private static final String NESTED_FIELD_1_VALUE = "john"; + private static final String NESTED_FIELD_2_VALUE = "black"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -271,6 +279,39 @@ public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + NestedQueryBuilder nestedQueryBuilder = QueryBuilders.nestedQuery( + TEST_NESTED_TYPE_FIELD_NAME_1, + matchQuery(TEST_NESTED_TYPE_FIELD_NAME_1 + "." + NESTED_FIELD_1, NESTED_FIELD_1_VALUE), + ScoreMode.Total + ); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(nestedQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(1, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(getMaxScore(searchResponseAsMap).get() > 0); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(1, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -344,13 +385,23 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { indexName, buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), - List.of("user"), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), 1 ), "" ); addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + List.of(), + List.of(), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index b8dc8fa4a..4e00cf0f5 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -5,8 +5,10 @@ package org.opensearch.neuralsearch.search.query; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -14,6 +16,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAX_NESTED_SUBQUERY_LIMIT; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; @@ -21,6 +24,7 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.Set; import lombok.SneakyThrows; @@ -30,6 +34,8 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -38,11 +44,13 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -83,7 +91,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -126,6 +135,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -152,7 +162,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -196,6 +207,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -205,8 +217,6 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); - MapperService mapperService = mock(MapperService.class); - when(searchContext.mapperService()).thenReturn(mapperService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -315,7 +325,8 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -365,6 +376,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -409,6 +421,295 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); + + Query query = boolQueryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("hybrid query must be a top level query and cannot be wrapped into other queries") + ); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + b.startObject("user"); + b.field("type", "nested"); + b.endObject(); + })); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); + when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD) + .add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); + assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); + List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + + TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); + List expectedIds1 = List.of(docId1); + assertQueryResults(subQueryTopDocs1, expectedIds1, reader); + + TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); + List expectedIds2 = List.of(); + assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testBoolQuery_whenTooManyNestedLevels_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + Query query = createNestedBoolQuery( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2).toQuery(mockQueryShardContext), + MAX_NESTED_SUBQUERY_LIMIT + 1 + ); + + when(searchContext.query()).thenReturn(query); + + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("reached max nested query limit, cannot process quer") + ); + + releaseResources(directory, w, reader); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); @@ -452,4 +753,15 @@ private List getSubQueryResultsForSingleShard(final TopDocs topDocs) { } return topDocsList; } + + private BooleanQuery createNestedBoolQuery(final Query query1, final Query query2, int level) { + if (level == 0) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(query1, BooleanClause.Occur.SHOULD).add(query2, BooleanClause.Occur.SHOULD); + return builder.build(); + } + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(createNestedBoolQuery(query1, query2, level - 1), BooleanClause.Occur.MUST); + return builder.build(); + } }