[Backport 2.x] Validate Disjunction query in HybridQueryPhaseSearcher #1144

Merged 1 commit on Jan 24, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054))
- Support for builder constructor in Neural Query Builder ([#1047](https://github.com/opensearch-project/neural-search/pull/1047))
- Validate Disjunction query to avoid having nested hybrid query ([#1127](https://github.com/opensearch-project/neural-search/pull/1127))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
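For context, the change rejects a hybrid query that is nested inside a dis_max query. Below is a minimal illustrative sketch of that query shape, built with the same query builders the new test later in this PR uses; the HybridQueryBuilder package path is assumed, and the snippet is not part of the diff itself:

```java
import org.opensearch.index.query.DisMaxQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder; // assumed package path

class NestedHybridQueryExample {
    // Builds a dis_max query that wraps a hybrid query; HybridQueryPhaseSearcher
    // now rejects this shape with "hybrid query must be a top level query and
    // cannot be wrapped into other queries".
    static DisMaxQueryBuilder hybridNestedInDisMax() {
        HybridQueryBuilder hybrid = new HybridQueryBuilder();
        hybrid.add(QueryBuilders.termQuery("text", "first"));
        hybrid.add(QueryBuilders.termQuery("text", "second"));

        return QueryBuilders.disMaxQuery()
            .add(hybrid)                                    // hybrid query as a disjunct -> rejected
            .add(QueryBuilders.termQuery("text", "third")); // a plain disjunct is fine on its own
    }
}
```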
src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -14,6 +14,7 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
@@ -104,7 +105,7 @@
* }
* ]
* }
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
* TODO add similar validation for other compound type queries like constant_score, function_score etc.
*
* @param query query to validate
*/
@@ -114,6 +115,10 @@
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext));
}
} else if (query instanceof DisjunctionMaxQuery) {
for (Query disjunct : (DisjunctionMaxQuery) query) {
validateNestedDisJunctionQuery(disjunct, getMaxDepthLimit(searchContext));
}
Codecov/patch warning on src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java: added lines L120-L121 are not covered by tests.
}
}

Expand All @@ -135,6 +140,24 @@
}
}

private void validateNestedDisJunctionQuery(final Query query, final int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level <= 0) {
// ideally we should throw an error here, but this code is on the main search workflow path and that might block
// execution of some queries. Instead, we silently exit and allow such a query to execute, which can produce incorrect
// results if a hybrid query is wrapped inside such a dis_max query
log.error("reached max nested query limit, cannot process dis_max query with that many nested clauses");
return;
Codecov/patch warning on src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java: added lines L151-L152 are not covered by tests.
}
if (query instanceof DisjunctionMaxQuery) {
for (Query disjunct : (DisjunctionMaxQuery) query) {
validateNestedDisJunctionQuery(disjunct, level - 1);
}
Codecov/patch warning on src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java: added lines L156-L157 are not covered by tests.
}
}
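The TODO earlier in this file mentions extending the same validation to constant_score and function_score queries. As a hedged illustration only, not part of this change, the recursive pattern above could unwrap Lucene's org.apache.lucene.search.ConstantScoreQuery through its getQuery() accessor; the method name below is hypothetical:

```java
// Hypothetical follow-up sketch for the TODO, mirroring validateNestedDisJunctionQuery;
// not part of this PR. Assumes ConstantScoreQuery is imported from org.apache.lucene.search.
private void validateNestedConstantScoreQuery(final Query query, final int level) {
    if (query instanceof HybridQuery) {
        throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
    }
    if (level <= 0) {
        // same trade-off as above: log and return instead of failing the whole search
        log.error("reached max nested query limit, cannot process constant_score query with that many nested clauses");
        return;
    }
    if (query instanceof ConstantScoreQuery) {
        validateNestedConstantScoreQuery(((ConstantScoreQuery) query).getQuery(), level - 1);
    }
}
```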

Codecov/patch warning on src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java: added line L159 is not covered by tests.

private int getMaxDepthLimit(final SearchContext searchContext) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
@@ -58,6 +58,7 @@
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.DisMaxQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.remote.RemoteStoreEnums;
@@ -516,6 +517,104 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
MapperService mapperService = mock(MapperService.class);
when(mapperService.hasNested()).thenReturn(false);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = getIndexMetadata();
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

// Create a HybridQueryBuilder
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
hybridQueryBuilder.paginationDepth(10);

// Create a regular term query
TermQueryBuilder termQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);

// Create a disjunction query (OR) with the hybrid query and the term query
DisMaxQueryBuilder disjunctionMaxQueryBuilder = QueryBuilders.disMaxQuery().add(hybridQueryBuilder).add(termQuery);

Query query = disjunctionMaxQueryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> hybridQueryPhaseSearcher.searchWith(
searchContext,
contextIndexSearcher,
query,
collectors,
hasFilterCollector,
hasTimeout
)
);

org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
);

releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();