From c2ddbab0401ead5beaeccd9acf227f4c19727b97 Mon Sep 17 00:00:00 2001 From: Isaac Johnson <114550967+Johnsonisaacn@users.noreply.github.com> Date: Fri, 18 Oct 2024 09:44:07 -0700 Subject: [PATCH 01/16] Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874) * initial commit of RRF Signed-off-by: Isaac Johnson Co-authored-by: Varun Jain Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 10 +- .../processor/NormalizationExecuteDTO.java | 35 +++ .../NormalizationProcessorWorkflow.java | 38 +-- .../processor/NormalizeScoresDTO.java | 26 ++ .../neuralsearch/processor/RRFProcessor.java | 140 ++++++++++ .../RRFScoreCombinationTechnique.java | 32 +++ .../combination/ScoreCombinationFactory.java | 4 +- .../combination/ScoreCombinationUtil.java | 5 +- .../factory/RRFProcessorFactory.java | 79 ++++++ .../L2ScoreNormalizationTechnique.java | 4 +- .../MinMaxScoreNormalizationTechnique.java | 4 +- .../RRFNormalizationTechnique.java | 106 ++++++++ .../ScoreNormalizationFactory.java | 18 +- .../ScoreNormalizationTechnique.java | 14 +- .../normalization/ScoreNormalizationUtil.java | 57 +++++ .../normalization/ScoreNormalizer.java | 15 +- .../plugin/NeuralSearchTests.java | 11 +- .../NormalizationProcessorTests.java | 6 +- .../NormalizationProcessorWorkflowTests.java | 100 ++++---- .../ScoreNormalizationTechniqueTests.java | 31 ++- .../RRFScoreCombinationTechniqueTests.java | 35 +++ .../ScoreCombinationFactoryTests.java | 8 + ....java => ScoreNormalizationUtilTests.java} | 2 +- .../NormalizationProcessorFactoryTests.java | 21 +- .../factory/RRFProcessorFactoryTests.java | 214 ++++++++++++++++ .../L2ScoreNormalizationTechniqueTests.java | 21 +- ...inMaxScoreNormalizationTechniqueTests.java | 19 +- .../RRFNormalizationTechniqueTests.java | 242 ++++++++++++++++++ .../ScoreNormalizationFactoryTests.java | 8 + .../query/OpenSearchQueryTestCase.java | 2 + .../query/HybridCollectorManagerTests.java | 1 - .../HybridQueryScoreDocsMergerTests.java | 2 - .../search/query/TopDocsMergerTests.java | 2 - 34 files changed, 1184 insertions(+), 129 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java rename src/test/java/org/opensearch/neuralsearch/processor/combination/{ScoreCombinationUtilTests.java => ScoreNormalizationUtilTests.java} (97%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6203e6e88..eb98378da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a 
Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features - Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048)) +- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874)) ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 1350a7963..f7ac5d19f 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -30,22 +30,24 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; -import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -157,7 +159,9 @@ public Map querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index db3747a13..51f30f842 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -22,14 +22,12 @@
 import org.opensearch.action.search.SearchPhaseContext;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
-import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
 import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
-import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.SearchHits;
@@ -57,34 +55,14 @@ public class NormalizationProcessorWorkflow {
 
     /**
      * Start execution of this workflow
-     * @param querySearchResults input data with QuerySearchResult from multiple shards
-     * @param normalizationTechnique technique for score normalization
-     * @param combinationTechnique technique for score combination
+     * @param request data transfer object that holds querySearchResults (QuerySearchResult from multiple
+     *        shards), fetchSearchResultOptional, the technique for score normalization, and the technique
+     *        for score combination; for RRF the rank constant is carried by the normalization technique
      */
-    public void execute(
-        final List<QuerySearchResult> querySearchResults,
-        final Optional<FetchSearchResult> fetchSearchResultOptional,
-        final ScoreNormalizationTechnique normalizationTechnique,
-        final ScoreCombinationTechnique combinationTechnique,
-        final SearchPhaseContext searchPhaseContext
-    ) {
-        NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
-            .querySearchResults(querySearchResults)
-            .fetchSearchResultOptional(fetchSearchResultOptional)
-            .normalizationTechnique(normalizationTechnique)
-            .combinationTechnique(combinationTechnique)
-            .explain(false)
-            .searchPhaseContext(searchPhaseContext)
-            .build();
-        execute(request);
-    }
-
     public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
         List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
         Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();
-
-        // save original state
-        List<String> unprocessedDocIds = unprocessedDocIds(querySearchResults);
+        List<String> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
 
         // pre-process data
         log.debug("Pre-process query results");
@@ -92,9 +70,15 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
         explain(request, queryTopDocs);
 
+        // data transfer object for score normalization; for RRF the rank constant travels inside the technique instance
+        NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
+            .queryTopDocs(queryTopDocs)
+            .normalizationTechnique(request.getNormalizationTechnique())
+            .build();
+
         // normalize
         log.debug("Do score normalization");
-        scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique());
+        scoreNormalizer.normalizeScores(normalizeScoresDTO);
 
         CombineScoresDto
combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java new file mode 100644 index 000000000..c932a157d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; + +import java.util.List; + +/** + * DTO object to hold data required for score normalization. + */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizeScoresDTO { + @NonNull + private List queryTopDocs; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java new file mode 100644 index 000000000..c8f78691a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; + +import java.util.stream.Collectors; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Processor for implementing reciprocal rank fusion technique on post + * query search results. Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + * by using ranks from individual subqueries to calculate 'normalized' + * scores before combining results from subqueries into final results + */ +@Log4j2 +@AllArgsConstructor +public class RRFProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "score-ranker-processor"; + + @Getter + private final String tag; + @Getter + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. 
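+     * For RRF specifically, normalization assigns each document a rank-based score of 1 / (rank_constant + rank)
+     * within its subquery, and combination sums those contributions across subqueries (see RRFNormalizationTechnique
+     * and RRFScoreCombinationTechnique elsewhere in this patch).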
+     * Exact methods for each processing stage are set as part of the class constructor
+     * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
+     * @param searchPhaseContext {@link SearchPhaseContext}
+     */
+    @Override
+    public <Result extends SearchPhaseResult> void process(
+        final SearchPhaseResults<Result> searchPhaseResult,
+        final SearchPhaseContext searchPhaseContext
+    ) {
+        if (shouldSkipProcessor(searchPhaseResult)) {
+            log.debug("Query results are not compatible with RRF processor");
+            return;
+        }
+        List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
+        Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
+
+        // make data transfer object to pass in; execute will receive an object with 4 or 5 fields set,
+        // depending on whether it comes from NormalizationProcessor or RRFProcessor
+        NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder()
+            .querySearchResults(querySearchResults)
+            .fetchSearchResultOptional(fetchSearchResult)
+            .normalizationTechnique(normalizationTechnique)
+            .combinationTechnique(combinationTechnique)
+            .explain(false)
+            .build();
+        normalizationWorkflow.execute(normalizationExecuteDTO);
+    }
+
+    @Override
+    public SearchPhaseName getBeforePhase() {
+        return SearchPhaseName.QUERY;
+    }
+
+    @Override
+    public SearchPhaseName getAfterPhase() {
+        return SearchPhaseName.FETCH;
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    @Override
+    public boolean isIgnoreFailure() {
+        return false;
+    }
+
+    private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
+        if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
+            return true;
+        }
+
+        return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
+    }
+
+    /**
+     * Return true if results are from a hybrid query.
+     * @param searchPhaseResult search phase result to inspect
+     * @return true if results are from a hybrid query
+     */
+    private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
+        // check that the first element of the score docs is the hybrid query start-stop delimiter
+        return Objects.nonNull(searchPhaseResult.queryResult())
+            && Objects.nonNull(searchPhaseResult.queryResult().topDocs())
+            && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
+            && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
+            && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
+    }
+
+    private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
+        final SearchPhaseResults<Result> results
+    ) {
+        return results.getAtomicArray()
+            .asList()
+            .stream()
+            .map(result -> result == null ?
null : result.queryResult()) + .collect(Collectors.toList()); + } + + private Optional getFetchSearchResults( + final SearchPhaseResults searchPhaseResults + ) { + Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); + return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java new file mode 100644 index 000000000..befe14dda --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import lombok.ToString; +import lombok.extern.log4j.Log4j2; + +import java.util.Map; + +@Log4j2 +/** + * Abstracts combination of scores based on reciprocal rank fusion algorithm + */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + + // Not currently using weights for RRF, no need to modify or verify these params + public RRFScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) {} + + @Override + public float combine(final float[] scores) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 23d8e01be..1e560342a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -25,7 +25,9 @@ public class ScoreCombinationFactory { HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), + RRFScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index 5f18baf09..99d0401d2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -26,6 +26,7 @@ public class ScoreCombinationUtil { public static final String PARAM_NAME_WEIGHTS = "weights"; private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + private static final float DELTA_FOR_WEIGHTS_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -117,7 +118,7 @@ protected void validateIfWeightsMatchScores(final float[] scores, final List weightsList) { - boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 
1.0f).contains(weight)); + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.of(0.0f, 1.0f).contains(weight)); if (isOutOfRange) { throw new IllegalArgumentException( String.format( @@ -128,7 +129,7 @@ private void validateWeights(final List weightsList) { ); } float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); - if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_WEIGHTS_ASSERTION)) { throw new IllegalArgumentException( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java new file mode 100644 index 000000000..fa4f39942 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; +import java.util.Objects; + +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.RRFScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.RRFNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +/** + * Factory class to instantiate RRF processor based on user provided input. 
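 * An illustrative pipeline definition this factory can consume is sketched below; the pipeline name and
 * rank_constant value are hypothetical, while the clause and field names match the constants defined in
 * this class and the processor type defined on RRFProcessor:
 *
 * PUT /_search/pipeline/rrf-pipeline
 * {
 *   "phase_results_processors": [
 *     {
 *       "score-ranker-processor": {
 *         "combination": {
 *           "technique": "rrf",
 *           "parameters": { "rank_constant": 60 }
 *         }
 *       }
 *     }
 *   ]
 * }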
+ */ +@AllArgsConstructor +@Log4j2 +public class RRFProcessorFactory implements Processor.Factory { + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + public static final String PARAMETERS = "parameters"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + // assign defaults + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization( + RRFNormalizationTechnique.TECHNIQUE_NAME + ); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = readStringProperty( + RRFProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + // check for optional combination params + Map params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + } + log.info( + "Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]", + RRFProcessor.TYPE, + normalizationTechnique, + scoreCombinationTechnique + ); + return new RRFProcessor(tag, description, normalizationTechnique, scoreCombinationTechnique, normalizationProcessorWorkflow); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index e7fbf658c..c9472938d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; @@ -39,7 +40,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); // get l2 norms for each sub-query List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java 
b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 3ca538f4e..da16d6c96 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -19,6 +19,7 @@ import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; @@ -43,7 +44,8 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); // do normalization using actual score and min and max scores for corresponding sub query for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java new file mode 100644 index 000000000..16ef83d05 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Locale; +import java.util.Set; + +import org.apache.commons.lang3.Range; +import org.apache.commons.lang3.math.NumberUtils; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; + +import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; + +/** + * Abstracts calculation of rank scores for each document returned as part of + * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. 
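+ * As a worked example, assume the default rank constant of 60: the documents at ranks 1, 2 and 3 of a
+ * subquery receive scores 1/61 ≈ 0.0164, 1/62 ≈ 0.0161 and 1/63 ≈ 0.0159, and a document that is ranked
+ * 1st in one subquery and 3rd in another combines to roughly 0.0164 + 0.0159 = 0.0323.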
+ */
+@ToString(onlyExplicitlyIncluded = true)
+public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
+    @ToString.Include
+    public static final String TECHNIQUE_NAME = "rrf";
+    public static final int DEFAULT_RANK_CONSTANT = 60;
+    public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant";
+    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT);
+    private static final int MIN_RANK_CONSTANT = 1;
+    private static final int MAX_RANK_CONSTANT = 10_000;
+    private static final Range<Integer> RANK_CONSTANT_RANGE = Range.of(MIN_RANK_CONSTANT, MAX_RANK_CONSTANT);
+    @ToString.Include
+    private final int rankConstant;
+
+    public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
+        scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS);
+        rankConstant = getRankConstant(params);
+    }
+
+    /**
+     * Reciprocal Rank Fusion normalization technique.
+     * Computes a new score for every document as follows, where document_n_score is the new score for each
+     * document in queryTopDocs and subquery_result_rank is the position of the document in the array
+     * returned for its subquery (j + 1 is used to adjust for 0 indexing):
+     * document_n_score = 1 / (rankConstant + subquery_result_rank)
+     * Document scores are summed across subqueries in the combination step. The rank constant defaults
+     * to 60 when not specified by the user.
+     * @param normalizeScoresDTO data transfer object that contains queryTopDocs, the original query results
+     *        from multiple shards and multiple sub-queries, and the normalization technique to apply
+     */
+    @Override
+    public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
+        final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (TopDocs topDocs : topDocsPerSubQuery) {
+                int docsCountPerSubQuery = topDocs.scoreDocs.length;
+                ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+                for (int j = 0; j < docsCountPerSubQuery; j++) {
+                    // using big decimal approach to minimize error caused by floating point ops
+                    // score = 1.f / (float) (rankConstant + j + 1)
+                    scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP)
+                        .floatValue();
+                }
+            }
+        }
+    }
+
+    private int getRankConstant(final Map<String, Object> params) {
+        if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
+            return DEFAULT_RANK_CONSTANT;
+        }
+        int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT);
+        validateRankConstant(rankConstant);
+        return rankConstant;
+    }
+
+    private void validateRankConstant(final int rankConstant) {
+        if (!RANK_CONSTANT_RANGE.contains(rankConstant)) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "rank constant must be in the interval between 1 and 10000, submitted rank constant: %d",
+                    rankConstant
+                )
+            );
+        }
+    }
+
+    public static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
+        try {
+            return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
+        } catch (NumberFormatException e) {
+            throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName));
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
index ca6ad20d6..7c62893a5 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
@@ -6,19 +6,24 @@
 import java.util.Map;
 import java.util.Optional;
+import java.util.function.Function;
 
 /**
  * Abstracts creation of exact score normalization method based on technique name
  */
 public class ScoreNormalizationFactory {
+    private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil();
+
     public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique();
 
-    private final Map<String, ScoreNormalizationTechnique> scoreNormalizationMethodsMap = Map.of(
+    private final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
         MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
-        new MinMaxScoreNormalizationTechnique(),
+        params -> new MinMaxScoreNormalizationTechnique(),
         L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
-        new L2ScoreNormalizationTechnique()
+        params -> new L2ScoreNormalizationTechnique(),
+        RRFNormalizationTechnique.TECHNIQUE_NAME,
+        params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil)
     );
 
     /**
@@ -27,7 +32,12 @@ public class ScoreNormalizationFactory {
      * @return instance of ScoreNormalizationMethod for technique name
     */
    public ScoreNormalizationTechnique createNormalization(final String technique) {
+        return createNormalization(technique, Map.of());
+    }
+
+    public ScoreNormalizationTechnique createNormalization(final String technique, final Map<String, Object> params) {
        return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique))
-            .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"));
+            .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"))
+            .apply(params);
    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java
index 0b784c678..f8190a728 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java
@@ -4,9 +4,7 @@
  */
 package org.opensearch.neuralsearch.processor.normalization;
 
-import java.util.List;
-
-import org.opensearch.neuralsearch.processor.CompoundTopDocs;
+import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
 
 /**
  * Abstracts normalization of scores in query search results.
@@ -14,8 +12,12 @@
 public interface ScoreNormalizationTechnique {
 
     /**
-     * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores.
-     * @param queryTopDocs original query results from multiple shards and multiple sub-queries
+     * Performs score normalization based on input normalization technique.
+     * Mutates input object by updating normalized scores.
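+     * A minimal caller sketch (assuming the NormalizeScoresDTO builder used elsewhere in this patch):
+     * {@code
+     * NormalizeScoresDTO dto = NormalizeScoresDTO.builder()
+     *     .queryTopDocs(queryTopDocs)
+     *     .normalizationTechnique(technique)
+     *     .build();
+     * technique.normalize(dto);
+     * }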
+     * @param normalizeScoresDTO data transfer object that holds queryTopDocs, the original query results
+     *        from multiple shards and multiple sub-queries, and the ScoreNormalizationTechnique to apply;
+     *        for RRF the rank constant is carried by the technique instance itself
      */
-    void normalize(final List<CompoundTopDocs> queryTopDocs);
+    void normalize(final NormalizeScoresDTO normalizeScoresDTO);
+
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
new file mode 100644
index 000000000..ad24b0aaa
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.normalization;
+
+import lombok.extern.log4j.Log4j2;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Collection of utility methods for score normalization technique classes
+ */
+@Log4j2
+class ScoreNormalizationUtil {
+    private static final String PARAM_NAME_WEIGHTS = "weights";
+    private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;
+
+    /**
+     * Validate config parameters for this technique
+     * @param actualParams map of parameters in form of name-value
+     * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique
+     */
+    public void validateParams(final Map<String, Object> actualParams, final Set<String> supportedParams) {
+        if (Objects.isNull(actualParams) || actualParams.isEmpty()) {
+            return;
+        }
+        // check if only supported params are passed
+        Optional<String> optionalNotSupportedParam = actualParams.keySet()
+            .stream()
+            .filter(paramName -> !supportedParams.contains(paramName))
+            .findFirst();
+        if (optionalNotSupportedParam.isPresent()) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "provided parameter for normalization technique is not supported. supported parameters are [%s]",
+                    String.join(",", supportedParams)
+                )
+            );
+        }
+
+        // check param types
+        if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
+            if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) {
+                throw new IllegalArgumentException(
+                    String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
+                );
+            }
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
index 67a17fda2..381ec9b9a 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
@@ -12,17 +12,22 @@
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
+import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
 
 public class ScoreNormalizer {
 
     /**
-     * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores.
-     * @param queryTopDocs original query results from multiple shards and multiple sub-queries
-     * @param scoreNormalizationTechnique exact normalization technique that should be applied
+     * Performs score normalization based on input normalization technique.
+     * Mutates input object by updating normalized scores.
+     * @param normalizeScoresDTO data transfer object used to pass in queryTopDocs (the original query results
+     *        from multiple shards and multiple sub-queries) together with the exact normalization technique
+     *        that should be applied; for RRF the rank constant is carried by the technique instance
      */
-    public void normalizeScores(final List<CompoundTopDocs> queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) {
+    public void normalizeScores(final NormalizeScoresDTO normalizeScoresDTO) {
+        final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
+        final ScoreNormalizationTechnique scoreNormalizationTechnique = normalizeScoresDTO.getNormalizationTechnique();
         if (canQueryResultsBeNormalized(queryTopDocs)) {
-            scoreNormalizationTechnique.normalize(queryTopDocs);
+            scoreNormalizationTechnique.normalize(normalizeScoresDTO);
         }
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
index 9a969e71b..a4ad9f2d4 100644
--- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
@@ -27,8 +27,10 @@
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
 import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessor;
+import org.opensearch.neuralsearch.processor.RRFProcessor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
+import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
 import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
@@ -143,12 +145,19 @@ public void testSearchPhaseResultsProcessors() {
         Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> searchPhaseResultsProcessors = plugin
             .getSearchPhaseResultsProcessors(searchParameters);
         assertNotNull(searchPhaseResultsProcessors);
-        assertEquals(1, searchPhaseResultsProcessors.size());
+        assertEquals(2, searchPhaseResultsProcessors.size());
+        // assert normalization processor conditions
         assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor"));
         org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get(
             NormalizationProcessor.TYPE
         );
         assertTrue(scoringProcessor instanceof NormalizationProcessorFactory);
+        // assert rrf processor conditions
+        assertTrue(searchPhaseResultsProcessors.containsKey("score-ranker-processor"));
+        org.opensearch.search.pipeline.Processor.Factory rankingProcessor = searchPhaseResultsProcessors.get(
+            RRFProcessor.TYPE
+        );
+        assertTrue(rankingProcessor instanceof RRFProcessorFactory);
     }
 
     public void testGetSettings() {
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
index 87dac8674..9f67327f1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
+++
b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -273,8 +273,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl ); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -329,8 +328,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 9969081a6..61828d822 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -81,13 +81,15 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); } @@ -123,19 +125,22 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); SearchRequest searchRequest = mock(SearchRequest.class); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + 
.normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } @@ -194,13 +199,15 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -260,13 +267,15 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -318,16 +327,15 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ) - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + expectThrows(IllegalStateException.class, () -> normalizationProcessorWorkflow.execute(normalizationExecuteDTO)); } public void 
testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -376,13 +384,15 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index b2b0007f6..fe7192ecd 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; + import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -22,7 +23,11 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); - scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(List.of()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); } @SneakyThrows @@ -36,7 +41,11 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -68,7 +77,11 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); 
assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -106,7 +119,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); @@ -177,7 +194,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); // shard one diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..daed466d3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; + +public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public RRFScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + private float RRF(List scores, List weights) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java index b36a6b492..5ca534dac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -34,6 +34,14 @@ public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstan assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); } + public void testRRF_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new 
ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("rrf"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof RRFScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java index 9e00e3833..009681116 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { +public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase { public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 9895d5b97..4cf1457ac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -167,22 +167,21 @@ public void testWeightsParams_whenInvalidValues_thenFail() { String tag = "tag"; String description = "description"; boolean ignoreFailure = false; + + // First value is always 0.5 + double first = 0.5; + // Second value is random between 0.3 and 1.0 + double second = 0.3 + (RandomizedTest.randomDouble() * 0.7); + // Third value is random between 0.3 and 1.0 + double third = 0.3 + (RandomizedTest.randomDouble() * 0.7); + // This ensures minimum sum of 1.1 (0.5 + 0.3 + 0.3) + Map config = new HashMap<>(); config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); config.put( COMBINATION_CLAUSE, new HashMap<>( - Map.of( - TECHNIQUE, - "arithmetic_mean", - PARAMETERS, - new HashMap<>( - Map.of( - "weights", - Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble()) - ) - ) - ) + Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(first, second, third)))) ) ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java new file mode 100644 index 000000000..3097402a0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java @@ -0,0 +1,214 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE; + +public class RRFProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoValuesPassed_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map<String, Object> config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testCombinationParams_whenValidValues_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map<String, Object> config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNegative_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; +
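+ // note: rank_constant is validated against the inclusive range [1, 10000] (per the + // error message asserted below), so the negative value used here must be rejected + // at processor creation time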
String description = "description"; + boolean ignoreFailure = false; + + Map<String, Object> config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", -1))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: -1") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map<String, Object> config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 50_000))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map<String, Object> config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", "string")))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("parameter [rank_constant] must be an integer")); + } + + @SneakyThrows + public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map<String, Object> config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "my_function", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100)))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( +
IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("provided combination technique is not supported")); + } + + @SneakyThrows + public void testInvalidTechniqueType_whenPassingNormalization_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map<String, Object> config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + config.put( + NORMALIZATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, PARAMETERS, new HashMap<>(Map.of()))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + private static void assertRRFProcessor(SearchPhaseResultsProcessor searchPhaseResultsProcessor) { + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof RRFProcessor); + RRFProcessor rrfProcessor = (RRFProcessor) searchPhaseResultsProcessor; + assertEquals("score-ranker-processor", rrfProcessor.getType()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index 734f9bb57..fc1663d75 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -13,9 +13,10 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** - * Abstracts normalization of scores based on min-max method + * Abstracts normalization of scores based on L2 method */ public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; @@ -37,7 +38,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) +
.normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -163,7 +172,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index c7692b407..85c54ea3a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -35,7 +36,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -77,7 +82,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -135,7 +144,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java new file mode 100644 index 000000000..1e1089846 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -0,0 +1,242 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ 
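+// A quick sketch of the arithmetic these tests encode (assuming the default rank constant of 60): +// the RRF-normalized score of the document at 0-based position r within a sub-query is 1 / (60 + r + 1), +// so positions 0, 1 and 2 map to roughly 0.0163934, 0.0161290 and 0.0158730 respectively; the +// rrfNorm helper below mirrors that BigDecimal division.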
+package org.opensearch.neuralsearch.processor.normalization; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; + +/** + * Abstracts testing of normalization of scores based on RRF method + */ +public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { + static final int RANK_CONSTANT = 60; + private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scores = { 0.5f, 0.2f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ) + ), + false, + SEARCH_SHARD + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresQuery1 = { 0.5f, 0.2f }; + float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresQuery2[0]), + new ScoreDoc(4, scoresQuery2[1]), + new ScoreDoc(2, scoresQuery2[2]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + 
List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false, + SEARCH_SHARD + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresShard1Query1 = { 0.5f, 0.2f }; + float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; + float[] scoresShard2Query2 = { 2.9f, 0.7f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[0]), + new ScoreDoc(4, scoresShard1and2Query3[1]), + new ScoreDoc(2, scoresShard1and2Query3[2]) } + ) + ), + false, + SEARCH_SHARD + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[3]), + new ScoreDoc(9, scoresShard1and2Query3[4]), + new ScoreDoc(10, scoresShard1and2Query3[5]), + new ScoreDoc(15, scoresShard1and2Query3[6]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false, + SEARCH_SHARD + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, rrfNorm(0)), new ScoreDoc(9, 
rrfNorm(1)) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, rrfNorm(3)), + new ScoreDoc(9, rrfNorm(4)), + new ScoreDoc(10, rrfNorm(5)), + new ScoreDoc(15, rrfNorm(6)) } + ) + ), + false, + SEARCH_SHARD + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); + } + } + + private float rrfNorm(int rank) { + // 1.0f / (float) (rank + RANK_CONSTANT + 1); + return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue(); + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index d9dcd5540..cecdf8779 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -26,6 +26,14 @@ public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() { assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique); } + public void testRRFNorm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("rrf"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof RRFNormalizationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index a1e8210e6..9c162ce11 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -53,6 +53,8 @@ public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase { + protected static final float DELTA_FOR_ASSERTION = 0.001f; + protected final MapperService createMapperService(Version version, XContentBuilder mapping) 
throws IOException { IndexMetadata meta = IndexMetadata.builder("index") .settings(Settings.builder().put("index.version.created", version)) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index e0d95f24e..f6948e3e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -85,7 +85,6 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; private static final String QUERY2 = "hi"; - private static final float DELTA_FOR_ASSERTION = 0.001f; protected static final String QUERY3 = "everyone"; @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 196014220..f91dae327 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -21,8 +21,6 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 9c2718687..2e064913f 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -27,8 +27,6 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { TopDocsMerger topDocsMerger = new TopDocsMerger(null); From 3a8dee6e0c8c82a32e4be303122b4250e2ec08b8 Mon Sep 17 00:00:00 2001 From: Isaac Johnson <114550967+Johnsonisaacn@users.noreply.github.com> Date: Fri, 18 Oct 2024 09:44:07 -0700 Subject: [PATCH 02/16] Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874) * initial commit of RRF Signed-off-by: Isaac Johnson Co-authored-by: Varun Jain Signed-off-by: Martin Gaievski --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4ee84ec8c..aa17e9ff3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -138,3 +138,4 @@ jobs: - name: Run build run: | ./gradlew precommit --parallel + From 2f886fab3a8457bd69abcc47f2a102d2ad4e09c8 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Dec 2024 08:59:26 -0800 Subject: [PATCH 03/16] Add integration and unit tests for missing RRF coverage (#997) * Initial unit test implementation Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan Signed-off-by: Martin Gaievski --- .../neuralsearch/processor/RRFProcessor.java | 14 +- .../MinMaxScoreNormalizationTechnique.java | 
16 +- .../processor/RRFProcessorIT.java | 93 +++++++ .../processor/RRFProcessorTests.java | 226 ++++++++++++++++++ .../neuralsearch/BaseNeuralSearchIT.java | 28 +++ 5 files changed, 366 insertions(+), 11 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index c8f78691a..ca67f2d1c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -13,6 +13,7 @@ import java.util.Optional; import lombok.Getter; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -99,7 +100,8 @@ public boolean isIgnoreFailure() { return false; } - private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + @VisibleForTesting + boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { return true; } @@ -112,7 +114,8 @@ private boolean shouldSkipProcessor(SearchPha * @param searchPhaseResult * @return true if results are from hybrid query */ - private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + @VisibleForTesting + boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { // check for delimiter at the end of the score docs. 
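// hybrid results are encoded with start/stop sentinel elements in the ScoreDoc array // (see HybridSearchResultFormatUtil), so inspecting the first element is a cheap check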
return Objects.nonNull(searchPhaseResult.queryResult()) && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) @@ -121,9 +124,7 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); } - private List getQueryPhaseSearchResults( - final SearchPhaseResults results - ) { + List getQueryPhaseSearchResults(final SearchPhaseResults results) { return results.getAtomicArray() .asList() .stream() @@ -131,7 +132,8 @@ private List getQueryPhase .collect(Collectors.toList()); } - private Optional getFetchSearchResults( + @VisibleForTesting + Optional getFetchSearchResults( final SearchPhaseResults searchPhaseResults ) { Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index da16d6c96..7da4c4330 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -12,6 +12,8 @@ import java.util.Map; import java.util.Objects; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -58,8 +60,8 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { scoreDoc.score = normalizeSingleScore( scoreDoc.score, - minMaxScores.minScoresPerSubquery()[j], - minMaxScores.maxScoresPerSubquery()[j] + minMaxScores.getMinScoresPerSubquery()[j], + minMaxScores.getMaxScoresPerSubquery()[j] ); } } @@ -96,8 +98,8 @@ public Map explain(final List new ArrayList<>()).add(normalizedScore); scoreDoc.score = normalizedScore; @@ -171,6 +173,10 @@ private float normalizeSingleScore(final float score, final float minScore, fina /** * Result class to hold min and max scores for each sub query */ - private record MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) { + @AllArgsConstructor + @Getter + private class MinMaxScores { + float[] minScoresPerSubquery; + float[] maxScoresPerSubquery; } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java new file mode 100644 index 000000000..fccabab5c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; + +public class RRFProcessorIT extends BaseNeuralSearchIT { + + private int 
currentDoc = 1; + private static final String RRF_INDEX_NAME = "rrf-index"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; + private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline"; + + private static final int RRF_DIMENSION = 5; + + @SneakyThrows + public void testRRF_whenValidInput_thenSucceed() { + try { + createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING); + prepareKnnIndex( + RRF_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE)) + ); + addDocuments(); + createDefaultRRFSearchPipeline(); + + HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(); + + Map<String, Object> results = search( + RRF_INDEX_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE) + ); + Map<String, Object> hits = (Map<String, Object>) results.get("hits"); + ArrayList<Map<String, Object>> hitsList = (ArrayList<Map<String, Object>>) hits.get("hits"); + assertEquals(3, hitsList.size()); + assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE); + } + } + + private HybridQueryBuilder getHybridQueryBuilder() { + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco"); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding") + .k(5) + .vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(knnQueryBuilder); + return hybridQueryBuilder; + } + + @SneakyThrows + private void addDocuments() { + addDocument( + "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .", + "4319130149.jpg" + ); + addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); + addDocument( + "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .", + "2664027527.jpg" + ); + addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); + addDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); + } + + @SneakyThrows + private void addDocument(String description, String imageText) { + addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java new file mode 100644 index 000000000..b7764128f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import 
org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RRFProcessorTests extends OpenSearchTestCase { + + @Mock + private ScoreNormalizationTechnique mockNormalizationTechnique; + @Mock + private ScoreCombinationTechnique mockCombinationTechnique; + @Mock + private NormalizationProcessorWorkflow mockNormalizationWorkflow; + @Mock + private SearchPhaseResults mockSearchPhaseResults; + @Mock + private SearchPhaseContext mockSearchPhaseContext; + @Mock + private QueryPhaseResultConsumer mockQueryPhaseResultConsumer; + + private RRFProcessor rrfProcessor; + private static final String TAG = "tag"; + private static final String DESCRIPTION = "description"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + MockitoAnnotations.openMocks(this); + rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow); + } + + @SneakyThrows + public void testGetType() { + assertEquals(RRFProcessor.TYPE, rrfProcessor.getType()); + } + + @SneakyThrows + public void testGetBeforePhase() { + assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase()); + } + + @SneakyThrows + public void testGetAfterPhase() { + assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase()); + } + + @SneakyThrows + public void testIsIgnoreFailure() { + assertFalse(rrfProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() { + rrfProcessor.process(null, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() { + rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenValidHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray 
atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testProcess_whenValidNonHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(false); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow, never()).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testGetTag() { + assertEquals(TAG, rrfProcessor.getTag()); + } + + @SneakyThrows + public void testGetDescription() { + assertEquals(DESCRIPTION, rrfProcessor.getDescription()); + } + + @SneakyThrows + public void testShouldSkipProcessor() { + assertTrue(rrfProcessor.shouldSkipProcessor(null)); + assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults)); + + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + + atomicArray.set(0, createQuerySearchResult(true)); + assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + } + + @SneakyThrows + public void testGetQueryPhaseSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(2); + atomicArray.set(0, createQuerySearchResult(true)); + atomicArray.set(1, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + List results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer); + assertEquals(2, results.size()); + assertNotNull(results.get(0)); + assertNotNull(results.get(1)); + } + + @SneakyThrows + public void testGetFetchSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(true)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + Optional result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer); + assertFalse(result.isPresent()); + } + + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { + ShardId shardId = new ShardId("index", "uuid", 0); + OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); + SearchRequest searchRequest = new SearchRequest("index"); + searchRequest.source(new SearchSourceBuilder()); + searchRequest.allowPartialSearchResults(true); + + int numberOfShards = 1; + AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY); + float indexBoost = 1.0f; + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = null; + String[] indexRoutings = Strings.EMPTY_ARRAY; + + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + numberOfShards, + aliasFilter, + indexBoost, + nowInMillis, + clusterAlias, + indexRoutings + ); + + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", shardId, clusterAlias, 
originalIndices), + shardSearchRequest + ); + result.from(0).size(10); + + ScoreDoc[] scoreDocs; + if (isHybrid) { + scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) }; + } else { + scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) }; + } + + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f); + result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]); + + return result; + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 00b75c575..965500497 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -93,6 +93,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; + protected static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -1555,4 +1556,31 @@ protected enum ProcessorType { SPARSE_ENCODING, SPARSE_ENCODING_PRUNE } + + @SneakyThrows + protected void createDefaultRRFSearchPipeline() { + String requestBody = XContentFactory.jsonBuilder() + .startObject() + .field("description", "Post processor for hybrid search") + .startArray("phase_results_processors") + .startObject() + .startObject("score-ranker-processor") + .startObject("combination") + .field("technique", "rrf") + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .toString(); + + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } } From 7d6599c98b13551bf2e00846c25d2aef509b6e6f Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 23 Dec 2024 08:53:43 -0800 Subject: [PATCH 04/16] Integrate explainability for hybrid query into RRF processor (#1037) * Integrate explainability for hybrid query into RRF processor Signed-off-by: Martin Gaievski --- .../AbstractScoreHybridizationProcessor.java | 65 +++++++ .../ExplanationResponseProcessor.java | 3 +- .../processor/NormalizationProcessor.java | 36 +--- .../neuralsearch/processor/RRFProcessor.java | 19 +- .../RRFScoreCombinationTechnique.java | 18 +- .../combination/ScoreCombinationFactory.java | 2 +- .../RRFNormalizationTechnique.java | 71 ++++++-- ...tractScoreHybridizationProcessorTests.java | 152 ++++++++++++++++ ...=> ExplanationResponseProcessorTests.java} | 116 ++++++++++++- .../processor/RRFProcessorTests.java | 33 ++++ .../RRFScoreCombinationTechniqueTests.java | 44 ++++- .../RRFNormalizationTechniqueTests.java | 54 ++++++ .../query/HybridQueryExplainIT.java | 162 ++++++++++++++++-- .../neuralsearch/BaseNeuralSearchIT.java | 24 ++- 14 files changed, 708 insertions(+), 91 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java rename 
src/test/java/org/opensearch/neuralsearch/processor/{ExplanationPayloadProcessorTests.java => ExplanationResponseProcessorTests.java} (76%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java new file mode 100644 index 000000000..456e8415a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import java.util.Optional; + +/** + * Base class for all score hybridization processors. This class is responsible for executing the score hybridization process. + * It is a pipeline processor that is executed after the query phase and before the fetch phase. + */ +public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor { + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is no pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public <Result extends SearchPhaseResult> void process( + final SearchPhaseResults<Result> searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param <Result> + */ + @Override + public <Result extends SearchPhaseResult> void process( + final SearchPhaseResults<Result> searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult + * @param searchPhaseContext + * @param requestContextOptional + * @param <Result> + */ + abstract <Result extends SearchPhaseResult> void hybridizeScores( + SearchPhaseResults<Result> searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional<PipelineProcessingContext> requestContextOptional + ); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 1cdd69b15..3423a2e29 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -111,8 +111,9 @@ public SearchResponse processResponse( ); } // Create and set final explanation combining all components + Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore(); Explanation finalExplanation = Explanation.match( - searchHit.getScore(), + finalScore, // combination level explanation is always a single detail combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2fa03fde..80499543e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -19,9 +19,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; -import org.opensearch.search.internal.SearchContext; import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; @@ -33,7 +31,7 @@ */ @Log4j2 @AllArgsConstructor -public class NormalizationProcessor implements SearchPhaseResultsProcessor { +public class NormalizationProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "normalization-processor"; private final String tag; @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { private final ScoreCombinationTechnique combinationTechnique; private final NormalizationProcessorWorkflow normalizationWorkflow; - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor. This method is called when there is no pipeline context - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. 
Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline - * @param - */ - @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); - } - - private void prepareAndExecuteNormalizationWorkflow( + void hybridizeScores( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index ca67f2d1c..100cf9fc6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -12,11 +12,12 @@ import java.util.Objects; import java.util.Optional; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; -import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.action.search.QueryPhaseResultConsumer; @@ -25,7 +26,6 @@ import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -39,7 +39,7 @@ */ @Log4j2 @AllArgsConstructor -public class RRFProcessor implements SearchPhaseResultsProcessor { +public class RRFProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "score-ranker-processor"; @Getter @@ -57,9 +57,10 @@ public class RRFProcessor implements SearchPhaseResultsProcessor { * @param searchPhaseContext {@link SearchContext} */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional ) { if (shouldSkipProcessor(searchPhaseResult)) { log.debug("Query results are not compatible with RRF processor"); @@ -67,7 +68,8 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending // on coming from NormalizationProcessor or RRFProcessor NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() @@ -75,7 +77,8 @@ public void process( .fetchSearchResultOptional(fetchSearchResult) .normalizationTechnique(normalizationTechnique) 
.combinationTechnique(combinationTechnique) - .explain(false) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) .build(); normalizationWorkflow.execute(normalizationExecuteDTO); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java index befe14dda..6d6c94b94 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -6,27 +6,39 @@ import lombok.ToString; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import java.util.Map; +import java.util.List; +import java.util.Objects; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; @Log4j2 /** * Abstracts combination of scores based on reciprocal rank fusion algorithm */ @ToString(onlyExplicitlyIncluded = true) -public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; // Not currently using weights for RRF, no need to modify or verify these params - public RRFScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) {} + public RRFScoreCombinationTechnique() {} @Override public float combine(final float[] scores) { + if (Objects.isNull(scores)) { + throw new IllegalArgumentException("scores array cannot be null"); + } float sumScores = 0.0f; for (float score : scores) { sumScores += score; } return sumScores; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, List.of()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 1e560342a..3f1996424 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -27,7 +27,7 @@ public class ScoreCombinationFactory { GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), RRFScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new RRFScoreCombinationTechnique() ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index 16ef83d05..80fc65eb3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -6,27 +6,34 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Locale; import java.util.Set; +import java.util.function.BiConsumer; +import 
java.util.stream.IntStream;
 
 import org.apache.commons.lang3.Range;
 import org.apache.commons.lang3.math.NumberUtils;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
 import org.opensearch.neuralsearch.processor.CompoundTopDocs;
 
 import lombok.ToString;
 import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
+import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
+import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
+import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
+
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
 
 /**
  * Abstracts calculation of rank scores for each document returned as part of
  * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes.
  */
 @ToString(onlyExplicitlyIncluded = true)
-public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
+public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
     @ToString.Include
     public static final String TECHNIQUE_NAME = "rrf";
     public static final int DEFAULT_RANK_CONSTANT = 60;
@@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map params, final ScoreNo
     public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
         final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
-            if (Objects.isNull(compoundQueryTopDocs)) {
-                continue;
-            }
-            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
-            for (TopDocs topDocs : topDocsPerSubQuery) {
-                int docsCountPerSubQuery = topDocs.scoreDocs.length;
-                ScoreDoc[] scoreDocs = topDocs.scoreDocs;
-                for (int j = 0; j < docsCountPerSubQuery; j++) {
-                    // using big decimal approach to minimize error caused by floating point ops
-                    // score = 1.f / (float) (rankConstant + j + 1))
-                    scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP)
-                        .floatValue();
-                }
-            }
+            processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
+        }
+    }
+
+    @Override
+    public String describe() {
+        return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
+    }
+
+    @Override
+    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
+        Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
+
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            processTopDocs(
+                compoundQueryTopDocs,
+                (docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
+            );
         }
+
+        return getDocIdAtQueryForNormalization(normalizedScores, this);
+    }
+
+    private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
+        if (Objects.isNull(compoundQueryTopDocs)) {
+            return;
+        }
+
+        compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
+            IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
+                float normalizedScore = calculateNormalizedScore(position);
+                DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
+                    topDocs.scoreDocs[position].doc,
+                    compoundQueryTopDocs.getSearchShard()
+                );
+                scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
+                topDocs.scoreDocs[position].score = normalizedScore;
+            });
+        });
+    }
+
+    // score = 1 / (rankConstant + position + 1); BigDecimal keeps floating point error small
+    private float calculateNormalizedScore(int position) {
+        return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue();
     }
 
     private int getRankConstant(final Map<String, Object> params) {
@@ -96,7 +131,7 @@ private
void validateRankConstant(final int rankConstant) { } } - public static int getParamAsInteger(final Map parameters, final String fieldName) { + private static int getParamAsInteger(final Map parameters, final String fieldName) { try { return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); } catch (NumberFormatException e) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java new file mode 100644 index 000000000..4e9ab59e5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AbstractScoreHybridizationProcessorTests extends OpenSearchTestCase { + private static final String TEST_TAG = "test_processor"; + private static final String TEST_DESCRIPTION = "Test Processor"; + + private TestScoreHybridizationProcessor processor; + private NormalizationProcessorWorkflow normalizationWorkflow; + + private static class TestScoreHybridizationProcessor extends AbstractScoreHybridizationProcessor { + private final String tag; + private final String description; + private final NormalizationProcessorWorkflow normalizationWorkflow1; + + TestScoreHybridizationProcessor(String tag, String description, NormalizationProcessorWorkflow normalizationWorkflow) { + this.tag = tag; + this.description = description; + normalizationWorkflow1 = normalizationWorkflow; + } + + @Override + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ) { + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .build(); + normalizationWorkflow1.execute(normalizationExecuteDTO); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.FETCH; + } + + @Override + public SearchPhaseName getAfterPhase() { + return 
SearchPhaseName.QUERY; + } + + @Override + public String getType() { + return "my_processor"; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + } + + @Before + public void setup() { + normalizationWorkflow = mock(NormalizationProcessorWorkflow.class); + + processor = new TestScoreHybridizationProcessor(TEST_TAG, TEST_DESCRIPTION, normalizationWorkflow); + } + + public void testProcessorMetadata() { + assertEquals(TEST_TAG, processor.getTag()); + assertEquals(TEST_DESCRIPTION, processor.getDescription()); + } + + public void testProcessWithExplanations() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchPhaseContext context = mock(SearchPhaseContext.class); + SearchPhaseResults results = mock(SearchPhaseResults.class); + + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(context.getRequest()).thenReturn(searchRequest); + + AtomicArray resultsArray = new AtomicArray<>(1); + QuerySearchResult queryResult = createQuerySearchResult(); + resultsArray.set(0, queryResult); + when(results.getAtomicArray()).thenReturn(resultsArray); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(results, context); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + verify(normalizationWorkflow).execute(any()); + } + + public void testProcess() { + SearchPhaseResults searchPhaseResult = mock(SearchPhaseResults.class); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + PipelineProcessingContext requestContext = mock(PipelineProcessingContext.class); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(searchPhaseResult, searchPhaseContext, requestContext); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + } + + private QuerySearchResult createQuerySearchResult() { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE), + null + ); + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, 1.0f) }); + result.topDocs(new TopDocsAndMaxScore(topDocs, 1.0f), new DocValueFormat[0]); + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java index 2e603d078..bfcd14251 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -37,9 +37,10 @@ import java.util.TreeMap; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class 
ExplanationPayloadProcessorTests extends OpenSearchTestCase { +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces assertOnExplanationResults(processedResponse, maxScore); } + @SneakyThrows + public void testProcessResponse_whenNullSearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = getSearchResponse(null); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenEmptySearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponse searchResponse = getSearchResponse(emptyHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenNullExplanation_thenSkipProcessing() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + for (SearchHit hit : searchHits.getHits()) { + hit.explanation(null); + } + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Set invalid payload + Map invalidPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + "invalid payload" + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenZeroScore_thenProcessCorrectly() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + 
SearchHits searchHits = getSearchHits(0.0f);
+        SearchResponse searchResponse = getSearchResponse(searchHits);
+        PipelineProcessingContext context = new PipelineProcessingContext();
+
+        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = getCombinedExplainDetails(searchHits);
+        Map<ExplanationPayload.PayloadType, Object> explainPayload = Map.of(
+            ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
+            combinedExplainDetails
+        );
+        ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build();
+        context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload);
+
+        SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
+        assertNotNull(processedResponse);
+        assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION);
+    }
+
+    @SneakyThrows
+    public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() {
+        ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
+        SearchRequest searchRequest = mock(SearchRequest.class);
+
+        // Create SearchHits with NaN score
+        SearchHits searchHits = getSearchHits(Float.NaN);
+        SearchResponse searchResponse = getSearchResponse(searchHits);
+        PipelineProcessingContext context = new PipelineProcessingContext();
+
+        // Setup explanation payload
+        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = getCombinedExplainDetails(searchHits);
+        Map<ExplanationPayload.PayloadType, Object> explainPayload = Map.of(
+            ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
+            combinedExplainDetails
+        );
+        ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build();
+        context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload);
+
+        // Process response
+        SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
+
+        // Verify results
+        assertNotNull(processedResponse);
+        SearchHit[] hits = processedResponse.getHits().getHits();
+        assertNotNull(hits);
+        assertTrue(hits.length > 0);
+
+        // Verify that the explanation uses 0.0f when input score was NaN
+        Explanation explanation = hits[0].getExplanation();
+        assertNotNull(explanation);
+        assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION);
+    }
+
     private static SearchHits getSearchHits(float maxScore) {
         int numResponses = 1;
         int numIndices = 2;
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java
index b7764128f..753c0b8fe 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java
@@ -9,6 +9,7 @@
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
 import org.junit.Before;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.opensearch.action.OriginalIndices;
@@ -111,6 +112,10 @@ public void testProcess_whenValidHybridInput_thenSucceed() {
 
         when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
 
+        SearchRequest searchRequest = new SearchRequest();
+        searchRequest.source(new SearchSourceBuilder());
+        when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest);
+
         rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);
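 
         // at this point the processor has handed the workflow a request shaped roughly like the
         // sketch below (a sketch only, with field values taken from the builder calls in
         // RRFProcessor.hybridizeScores shown earlier in this patch):
         //   NormalizationProcessorWorkflowExecuteRequest.builder()
         //       .querySearchResults(querySearchResults)
         //       .fetchSearchResultOptional(fetchSearchResult)
         //       .normalizationTechnique(normalizationTechnique)
         //       .combinationTechnique(combinationTechnique)
         //       .explain(false)                   // no explain flag set on this request's source
         //       .pipelineProcessingContext(null)  // two-arg process() passes Optional.empty()
         //       .build();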
verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); @@ -177,6 +182,34 @@ public void testGetFetchSearchResults() { assertFalse(result.isPresent()); } + @SneakyThrows + public void testProcess_whenExplainIsTrue_thenExplanationIsAdded() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + // Capture the actual request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass( + NormalizationProcessorWorkflowExecuteRequest.class + ); + verify(mockNormalizationWorkflow).execute(requestCaptor.capture()); + + // Verify the captured request + NormalizationProcessorWorkflowExecuteRequest capturedRequest = requestCaptor.getValue(); + assertTrue(capturedRequest.isExplain()); + assertNull(capturedRequest.getPipelineProcessingContext()); + } + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { ShardId shardId = new ShardId("index", "uuid", 0); OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java index daed466d3..39b4dd4e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -5,26 +5,62 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Map; public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { - private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + private static final int RANK_CONSTANT = 60; + private RRFScoreCombinationTechnique combinationTechnique; public RRFScoreCombinationTechniqueTests() { this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + combinationTechnique = new RRFScoreCombinationTechnique(); } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } + public void testDescribe() { + String description = combinationTechnique.describe(); + assertEquals("rrf", description); + } + + public void testCombineWithEmptyInput() { + float[] scores = new float[0]; + float result = combinationTechnique.combine(scores); + 
assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithSingleScore() { + float[] scores = new float[] { 0.5f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.5f, result, 0.001f); + } + + public void testCombineWithMultipleScores() { + float[] scores = new float[] { 0.8f, 0.6f, 0.4f }; + float result = combinationTechnique.combine(scores); + float expected = 0.8f + 0.6f + 0.4f; + assertEquals(expected, result, 0.001f); + } + + public void testCombineWithZeroScores() { + float[] scores = new float[] { 0.0f, 0.0f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithNullInput() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> combinationTechnique.combine(null)); + assertEquals("scores array cannot be null", exception.getMessage()); + } + private float RRF(List scores, List weights) { float sumScores = 0.0f; for (float score : scores) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java index 1e1089846..da6d37bd7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -10,10 +10,14 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -25,6 +29,16 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testDescribe() { + // verify with default values for parameters + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe()); + + // verify when parameter values are set + normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe()); + } + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); float[] scores = { 0.5f, 0.2f }; @@ -224,6 +238,27 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the } } + public void testExplainWithEmptyAndNullList() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + normalizationTechnique.explain(List.of()); + + List compoundTopDocs = new ArrayList<>(); + compoundTopDocs.add(null); + normalizationTechnique.explain(compoundTopDocs); + } + + public void 
testExplainWithSingleTopDocs() {
+        RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil);
+        CompoundTopDocs topDocs = createCompoundTopDocs(new float[] { 0.8f }, 1);
+        List<CompoundTopDocs> queryTopDocs = Collections.singletonList(topDocs);
+
+        Map<DocIdAtSearchShard, ExplanationDetails> explanation = normalizationTechnique.explain(queryTopDocs);
+
+        assertNotNull(explanation);
+        assertEquals(1, explanation.size());
+        assertTrue(explanation.containsKey(new DocIdAtSearchShard(0, new SearchShard("test_index", 0, "uuid"))));
+    }
+
     private float rrfNorm(int rank) {
         // 1.0f / (float) (rank + RANK_CONSTANT + 1);
         return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue();
     }
@@ -239,4 +274,23 @@ private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
             assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex);
         }
     }
+
+    // builds a CompoundTopDocs wrapping a single sub-query result with the given scores
+    private CompoundTopDocs createCompoundTopDocs(float[] scores, int size) {
+        ScoreDoc[] scoreDocs = new ScoreDoc[size];
+        for (int i = 0; i < size; i++) {
+            scoreDocs[i] = new ScoreDoc(i, scores[i]);
+        }
+        TopDocs singleTopDocs = new TopDocs(new TotalHits(size, TotalHits.Relation.EQUAL_TO), scoreDocs);
+
+        List<TopDocs> topDocsList = Collections.singletonList(singleTopDocs);
+        SearchShard searchShard = new SearchShard("test_index", 0, "uuid");
+
+        return new CompoundTopDocs(
+            new TotalHits(size, TotalHits.Relation.EQUAL_TO),
+            topDocsList,
+            false, // isSortEnabled
+            searchShard
+        );
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
index b7e4f753a..c6eaa21ff 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
@@ -58,7 +58,8 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT {
     private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
-    private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline";
+    private static final String NORMALIZATION_SEARCH_PIPELINE = "normalization-search-pipeline";
+    private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
 
     static final Supplier<float[]> TEST_VECTOR_SUPPLIER = () -> new float[768];
 
@@ -78,7 +79,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() {
         try {
             initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
             // create search pipeline with both normalization processor and explain response processor
-            createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+            createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
 
             TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
             TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
@@ -95,7 +96,7 @@
                 hybridQueryBuilderNeuralThenTerm,
                 null,
                 10,
-                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+                Map.of("search_pipeline",
NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -187,7 +188,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -196,7 +197,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipeline( - SEARCH_PIPELINE, + NORMALIZATION_SEARCH_PIPELINE, NORMALIZATION_TECHNIQUE_L2, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), @@ -217,7 +218,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() hybridQueryBuilder, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // basic sanity check for search hits @@ -322,7 +323,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(0, getListOfValues(explanationsHit4, "details").size()); assertTrue((double) explanationsHit4.get("value") > 0.0f); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -331,7 +332,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with normalization processor, no explanation response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -348,7 +349,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -463,7 +464,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe assertEquals("boost", explanationsHit3Details.get("description")); assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -472,7 +473,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, 
DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+            createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
 
             TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
 
             HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
@@ -483,7 +484,7 @@
                 hybridQueryBuilder,
                 null,
                 MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX,
-                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+                Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
             );
 
             List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
@@ -521,7 +522,7 @@
             }
             assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
         } finally {
-            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
+            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE);
         }
     }
 
@@ -530,7 +531,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() {
         try {
             initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
             // create search pipeline with both normalization processor and explain response processor
-            createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+            createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
 
             HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
             hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2));
@@ -543,7 +544,7 @@
                 hybridQueryBuilder,
                 null,
                 MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX,
-                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+                Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
             );
 
             List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
@@ -581,7 +582,136 @@
             }
             assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
         } finally {
-            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
+            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testExplain_whenRRFProcessor_thenSuccessful() {
+        try {
+            initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);
+            createRRFSearchPipeline(RRF_SEARCH_PIPELINE, true);
+
+            HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+            KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
+                .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
+                .vector(createRandomVector(TEST_DIMENSION))
+                .k(10)
+                .build();
+            hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1));
+            hybridQueryBuilder.add(knnQueryBuilder);
+
+            Map<String, Object> searchResponseAsMap = search(
+                TEST_MULTI_DOC_INDEX_NAME,
+                hybridQueryBuilder,
+                null,
+                10,
+                Map.of("search_pipeline", RRF_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+            );
+            // Assert
+            // basic sanity check for search hits
+            assertEquals(4, getHitCount(searchResponseAsMap));
+            assertTrue(getMaxScore(searchResponseAsMap).isPresent());
+            float actualMaxScore = getMaxScore(searchResponseAsMap).get();
+            assertTrue(actualMaxScore > 0);
+            Map<String, Object> total = getTotalHits(searchResponseAsMap);
+            assertNotNull(total.get("value"));
+            assertEquals(4, total.get("value"));
+            assertNotNull(total.get("relation"));
+            assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+
+            // explain, hit 1
+            List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+            Map<String, Object> searchHit1 = hitsNestedList.get(0);
+            Map<String, Object> explanationForHit1 = getValueByKey(searchHit1, "_explanation");
+            assertNotNull(explanationForHit1);
+            assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION);
+            String expectedTopLevelDescription = "rrf combination of:";
+            assertEquals(expectedTopLevelDescription, explanationForHit1.get("description"));
+            List<Map<String, Object>> hit1Details = getListOfValues(explanationForHit1, "details");
+            assertEquals(2, hit1Details.size());
+            // two sub-queries meaning we do have two detail objects with separate query level details
+            Map<String, Object> hit1DetailsForHit1 = hit1Details.get(0);
+            assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit1.get("description"));
+            assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size());
+
+            Map<String, Object> explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0);
+            assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description"));
+            assertTrue((double) explanationsHit1.get("value") > 0.5f);
+            assertEquals(0, ((List) explanationsHit1.get("details")).size());
+
+            Map<String, Object> hit1DetailsForHit2 = hit1Details.get(1);
+            assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit2.get("description"));
+            assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size());
+
+            Map<String, Object> explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0);
+            assertEquals("within top 10", explanationsHit2.get("description"));
+            assertTrue((double) explanationsHit2.get("value") > 0.0f);
+            assertEquals(0, ((List) explanationsHit2.get("details")).size());
+
+            // hit 2
+            Map<String, Object> searchHit2 = hitsNestedList.get(1);
+            Map<String, Object> explanationForHit2 = getValueByKey(searchHit2, "_explanation");
+            assertNotNull(explanationForHit2);
+            assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
+
+            assertEquals(expectedTopLevelDescription, explanationForHit2.get("description"));
+            List<Map<String, Object>> hit2Details = getListOfValues(explanationForHit2, "details");
+            assertEquals(2, hit2Details.size());
+
+            Map<String, Object> hit2DetailsForHit1 = hit2Details.get(0);
+            assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit1.get("description"));
+            assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size());
+
+            Map<String, Object> hit2DetailsForHit2 = hit2Details.get(1);
+            assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit2.get("description"));
+            assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size());
+
+            // hit 3
+            Map<String, Object> searchHit3 = hitsNestedList.get(2);
+            Map<String, Object> explanationForHit3 = getValueByKey(searchHit3, "_explanation");
+            assertNotNull(explanationForHit3);
+            assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION);
+
+            assertEquals(expectedTopLevelDescription, explanationForHit3.get("description"));
+            List<Map<String, Object>> hit3Details = getListOfValues(explanationForHit3, "details");
+            assertEquals(1, hit3Details.size());
+
+            Map<String, Object> hit3DetailsForHit1 = hit3Details.get(0);
+            assertTrue((double) hit3DetailsForHit1.get("value") > .0f);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit3DetailsForHit1.get("description"));
+            assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size());
+
+            Map<String, Object> explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0);
+            assertEquals("within top 10", explanationsHit3.get("description"));
+            assertEquals(0, getListOfValues(explanationsHit3, "details").size());
+            assertTrue((double) explanationsHit3.get("value") > 0.0f);
+
+            // hit 4
+            Map<String, Object> searchHit4 = hitsNestedList.get(3);
+            Map<String, Object> explanationForHit4 = getValueByKey(searchHit4, "_explanation");
+            assertNotNull(explanationForHit4);
+            assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION);
+
+            assertEquals(expectedTopLevelDescription, explanationForHit4.get("description"));
+            List<Map<String, Object>> hit4Details = getListOfValues(explanationForHit4, "details");
+            assertEquals(1, hit4Details.size());
+
+            Map<String, Object> hit4DetailsForHit1 = hit4Details.get(0);
+            assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
+            assertEquals("rrf, rank_constant [60] normalization of:", hit4DetailsForHit1.get("description"));
+            assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size());
+
+            Map<String, Object> explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0);
+            assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description"));
+            assertEquals(0, getListOfValues(explanationsHit4, "details").size());
+            assertTrue((double) explanationsHit4.get("value") > 0.0f);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, RRF_SEARCH_PIPELINE);
         }
     }
 
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index 965500497..95757e463 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -1559,7 +1559,12 @@ protected enum ProcessorType {
 
     @SneakyThrows
     protected void createDefaultRRFSearchPipeline() {
-        String requestBody = XContentFactory.jsonBuilder()
+        createRRFSearchPipeline(RRF_SEARCH_PIPELINE, false);
+    }
+
+    @SneakyThrows
+    protected void createRRFSearchPipeline(final String pipelineName, boolean addExplainResponseProcessor) {
+        XContentBuilder builder = XContentFactory.jsonBuilder()
             .startObject()
             .field("description", "Post processor for hybrid search")
             .startArray("phase_results_processors")
             .startObject()
             .startObject()
             .endObject()
             .endObject()
             .endObject()
-            .endArray()
-            .endObject()
-            .toString();
+            .endArray();
+
+        if (addExplainResponseProcessor) {
+            builder.startArray("response_processors")
+                .startObject()
+                .startObject("explanation_response_processor")
+                .endObject()
+                .endObject()
+                .endArray();
+        }
+
+        String requestBody = builder.endObject().toString();
 
         makeRequest(
             client(),
             "PUT",
-            String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE),
+            String.format(LOCALE, "/_search/pipeline/%s", pipelineName),
             null,
             toHttpEntity(String.format(LOCALE, requestBody)),
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT,
DEFAULT_USER_AGENT)) From 9ce78f537169c1691c37faf11342e46dfccd8fd8 Mon Sep 17 00:00:00 2001 From: Bo Zhang Date: Tue, 17 Dec 2024 16:11:11 -0800 Subject: [PATCH 05/16] Support of new k-NN query parameter expand_nested. (#1013) Signed-off-by: Bo Zhang --- .github/workflows/CI.yml | 1 - .../org/opensearch/neuralsearch/bwc/HybridSearchIT.java | 5 +++++ .../opensearch/neuralsearch/bwc/KnnRadialSearchIT.java | 1 - .../org/opensearch/neuralsearch/bwc/HybridSearchIT.java | 9 ++++++--- .../opensearch/neuralsearch/bwc/KnnRadialSearchIT.java | 1 - .../opensearch/neuralsearch/bwc/MultiModalSearchIT.java | 1 - .../neuralsearch/query/NeuralQueryBuilder.java | 1 - .../neuralsearch/processor/NormalizationProcessorIT.java | 1 - .../neuralsearch/processor/ScoreCombinationIT.java | 3 --- 9 files changed, 11 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index aa17e9ff3..4ee84ec8c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -138,4 +138,3 @@ jobs: - name: Run build run: | ./gradlew precommit --parallel - diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 56ffec24a..5cf6ab02c 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -13,6 +13,8 @@ import org.opensearch.index.query.MatchQueryBuilder; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; @@ -131,6 +133,9 @@ private HybridQueryBuilder getQueryBuilder( if (expandNestedDocs != null) { neuralQueryBuilder.expandNested(expandNestedDocs); } + if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName()) && expandNestedDocs != null) { + neuralQueryBuilder.expandNested(expandNestedDocs); + } if (methodParameters != null) { neuralQueryBuilder.methodParameters(methodParameters); } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index c3f461871..838d7ae8a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -69,7 +69,6 @@ private void validateIndexQuery(final String modelId) { .modelId(modelId) .maxDistance(100000f) .build(); - Map responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); assertNotNull(responseWithMaxDistanceQuery); } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 44671ed4a..f64ddd455 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -20,6 +20,7 @@ import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import 
org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -67,7 +68,8 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null); + QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer); addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; @@ -82,9 +84,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); + QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault()); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index 88af8b757..52d2ee173 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -95,7 +95,6 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo .modelId(modelId) .maxDistance(100000f) .build(); - Map responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); assertNotNull(responseWithMaxScore); } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index e2df88d6d..4dc33a15b 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -83,7 +83,6 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod .modelId(modelId) .k(1) .build(); - Map responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1); assertNotNull(responseWithKQuery); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 42d56b85c..bdaea1567 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ 
b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -302,7 +302,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) { out.writeOptionalBoolean(this.expandNested); } - if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) { MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 3c5fc08ef..72b74b28c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -139,7 +139,6 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu .modelId(modelId) .k(5) .build(); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index feb914e30..88dbfa735 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -250,7 +250,6 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() - ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -301,7 +300,6 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() - ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -327,7 +325,6 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() - ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); From 1339c79819e7a97714009e211bf8e4d3e04eda89 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 18 Dec 2024 10:52:25 +0800 Subject: [PATCH 06/16] [Enhancement] Implement pruning for neural sparse search (#988) * add impl Signed-off-by: zhichao-aws * add UT Signed-off-by: zhichao-aws * rename pruneType; UT Signed-off-by: zhichao-aws * changelog Signed-off-by: zhichao-aws * ut Signed-off-by: zhichao-aws * add it Signed-off-by: zhichao-aws * change on 2-phase Signed-off-by: zhichao-aws * UT Signed-off-by: zhichao-aws * it Signed-off-by: zhichao-aws * rename Signed-off-by: zhichao-aws * enhance: more detailed error message Signed-off-by: zhichao-aws * refactor to prune and split Signed-off-by: 
zhichao-aws * changelog Signed-off-by: zhichao-aws * fix UT cov Signed-off-by: zhichao-aws * address review comments Signed-off-by: zhichao-aws * enlarge score diff range Signed-off-by: zhichao-aws * address comments: check lowScores non null instead of flag Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws --- .../neuralsearch/processor/SparseEncodingProcessIT.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index b5e14a11f..84ab61750 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -4,9 +4,18 @@ */ package org.opensearch.neuralsearch.processor; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Map; +import com.google.common.collect.ImmutableList; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; From 0afd102908ab34a13ade8680c6460f0204626e39 Mon Sep 17 00:00:00 2001 From: Bo Zhang Date: Wed, 18 Dec 2024 10:05:55 -0800 Subject: [PATCH 07/16] Remove mistakenly added code from HybridSearchIT. (#1032) Signed-off-by: Bo Zhang --- .../java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 5cf6ab02c..81b861187 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -133,9 +133,6 @@ private HybridQueryBuilder getQueryBuilder( if (expandNestedDocs != null) { neuralQueryBuilder.expandNested(expandNestedDocs); } - if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName()) && expandNestedDocs != null) { - neuralQueryBuilder.expandNested(expandNestedDocs); - } if (methodParameters != null) { neuralQueryBuilder.methodParameters(methodParameters); } From d5a176e623e8348bd5695b5badd55315d1d0322a Mon Sep 17 00:00:00 2001 From: Yizhe Liu <59710443+yizheliu-amazon@users.noreply.github.com> Date: Fri, 3 Jan 2025 09:16:37 -0800 Subject: [PATCH 08/16] Fix bug where ingestion failed for input document containing list of nested objects (#1040) * Fix bug where ingestion failed for input document containing list of nested objects Signed-off-by: Yizhe Liu * Address comments to use better method name/implementation Signed-off-by: Yizhe Liu * Address comments: modify the test case to have doc with various fields Signed-off-by: Yizhe Liu --------- Signed-off-by: Yizhe Liu --- .../neuralsearch/bwc/HybridSearchIT.java | 2 - .../processor/InferenceProcessor.java | 48 +- .../processor/SparseEncodingProcessIT.java | 9 - .../TextEmbeddingProcessorTests.java | 443 +++++------------- 4 files changed, 137 insertions(+), 365 deletions(-) diff --git 
a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 81b861187..56ffec24a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -13,8 +13,6 @@ import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD; -import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 3fb45ceeb..ff1b663f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -137,7 +137,6 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { - preprocessIngestDocument(ingestDocument); validateEmbeddingFieldsValue(ingestDocument); Map processMap = buildMapWithTargetKeys(ingestDocument); List inferenceList = createInferenceList(processMap); @@ -151,15 +150,6 @@ public void execute(IngestDocument ingestDocument, BiConsumer sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap); - unflattened.forEach(ingestDocument::setFieldValue); - sourceAndMetadataMap.keySet().removeIf(key -> key.contains(".")); - } - /** * This is the function which does actual inference work for batchExecute interface. * @param inferenceList a list of String for inference. 
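// Illustrative aside, not part of this patch: the "inferenceList" referenced in the
// javadoc above is typically flattened out of the per-field process map before the
// model call. A minimal, hypothetical sketch of that extraction (the plugin's real
// createInferenceList(...) may differ; names here are assumptions):
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

private static List<String> createInferenceListSketch(Map<String, Object> processMap) {
    List<String> texts = new ArrayList<>();
    for (Object value : processMap.values()) {
        if (value instanceof String) {
            texts.add((String) value);                 // plain text field
        } else if (value instanceof List) {
            for (Object element : (List<?>) value) {   // list field: collect string entries
                if (element instanceof String) {
                    texts.add((String) element);
                }
            }
        }
        // nested maps would be walked recursively in the real processor
    }
    return texts;
}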
@@ -254,14 +244,12 @@ private List getDataForInference(List i for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) { Map processMap = null; List inferenceList = null; - IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument(); try { - preprocessIngestDocument(ingestDocument); - validateEmbeddingFieldsValue(ingestDocument); - processMap = buildMapWithTargetKeys(ingestDocument); + validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument()); + processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument()); inferenceList = createInferenceList(processMap); } catch (Exception e) { - ingestDocumentWrapper.update(ingestDocument, e); + ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e); } finally { dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList)); } @@ -319,7 +307,7 @@ Map buildMapWithTargetKeys(IngestDocument ingestDocument) { buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes); mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey))); + mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); } } return mapWithProcessorKeys; @@ -345,33 +333,17 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { List> list = (List>) sourceAndMetadataMap.get(parentKey); - Pair processedNestedKey = processNestedKey(nestedFieldMapEntry); - List listOfStrings = list.stream().map(x -> { - Object nestedSourceValue = x.get(processedNestedKey.getKey()); - return normalizeSourceValue(nestedSourceValue); - }).collect(Collectors.toList()); + List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); Map map = new LinkedHashMap<>(); - map.put(processedNestedKey.getKey(), listOfStrings); - buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next); + map.put(nestedFieldMapEntry.getKey(), listOfStrings); + buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next); } } treeRes.merge(parentKey, next, REMAPPING_FUNCTION); } else { - Object parentValue = sourceAndMetadataMap.get(parentKey); String key = String.valueOf(processorKey); - treeRes.put(key, normalizeSourceValue(parentValue)); - } - } - - private boolean isBlankString(Object object) { - return object instanceof String && StringUtils.isBlank((String) object); - } - - private Object normalizeSourceValue(Object value) { - if (isBlankString(value)) { - return null; + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } - return value; } /** @@ -400,11 +372,11 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { ProcessorDocumentUtils.validateMapTypeValue( FIELD_MAP_FIELD, sourceAndMetadataMap, - ProcessorDocumentUtils.unflattenJson(fieldMap), + fieldMap, indexName, clusterService, environment, - true + false ); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 84ab61750..b5e14a11f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ 
b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -4,18 +4,9 @@ */ package org.opensearch.neuralsearch.processor; -import java.nio.file.Files; -import java.nio.file.Path; import java.util.Map; -import com.google.common.collect.ImmutableList; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; -import org.opensearch.client.Response; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 8fedd1fca..a0ba8780c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -29,7 +29,6 @@ import java.util.function.Supplier; import java.util.stream.IntStream; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -246,6 +245,31 @@ public void testExecute_withListTypeInput_successful() { verify(handler).accept(any(IngestDocument.class), isNull()); } + public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", " "); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() { + List list1 = ImmutableList.of("", "test2", "test3"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", list1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { List list2 = ImmutableList.of(1, 2, 3); Map sourceAndMetadata = new HashMap<>(); @@ -486,34 +510,28 @@ public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWitho ] */ - Map child1Level2 = buildObjMap(Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1)); - Map child1Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child1Level2)); - Map child2Level2 = buildObjMap( - Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1), - Pair.of(CHILD_2_TEXT_FIELD, TEXT_VALUE_2), - Pair.of(CHILD_3_TEXT_FIELD, TEXT_VALUE_3) - ); - Map child2Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child2Level2)); - Map sourceAndMetadata = buildObjMap( - Pair.of(PARENT_FIELD, Arrays.asList(child1Level1, child2Level1)), - Pair.of(IndexFieldMapper.NAME, "my_index") + Map child1Level2 = 
buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map child1Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child1Level2); + Map child2Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + child2Level2.put(CHILD_2_TEXT_FIELD, TEXT_VALUE_2); + child2Level2.put(CHILD_3_TEXT_FIELD, TEXT_VALUE_3); + Map child2Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child2Level2); + Map sourceAndMetadata = Map.of( + PARENT_FIELD, + Arrays.asList(child1Level1, child2Level1), + IndexFieldMapper.NAME, + "my_index" ); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); - Map config = buildObjMap( - Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), - Pair.of( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - buildObjMap( - Pair.of( - PARENT_FIELD, - Map.of( - CHILD_FIELD_LEVEL_1, - Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD)) - ) - ) - ) + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + Map.of( + PARENT_FIELD, + Map.of(CHILD_FIELD_LEVEL_1, Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD))) ) ); TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( @@ -614,6 +632,20 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } + public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", " "); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("key2", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map ret = createMaxDepthLimitExceedMap(() -> 1); Map sourceAndMetadata = new HashMap<>(); @@ -776,103 +808,6 @@ public void testBuildVectorOutput_withNestedMap_successful() { } } - @SneakyThrows - @SuppressWarnings("unchecked") - public void testBuildVectorOutput_withFlattenedNestedMap_successful() { - Map config = createNestedMapConfiguration(); - IngestDocument ingestDocument = createFlattenedNestedMapIngestDocument(); - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - processor.preprocessIngestDocument(ingestDocument); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f); - processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); - /** - * "favorites.favorite": { - * "movie": "matrix", - * "actor": "Charlie Chaplin", - * "games" : { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * } - * } - */ - Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - assertNotNull(favoritesMap); - Map favorites = (Map) favoritesMap.get("favorite"); - assertNotNull(favorites); - 
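// Illustrative aside, not part of this diff: the flattened-key documents these
// tests exercise are expanded from dotted keys into nested maps, e.g.
// {"favorites.favorite": {...}} becomes {"favorites": {"favorite": {...}}}.
// A minimal unflatten sketch under that assumption -- the plugin's
// ProcessorDocumentUtils.unflattenJson may handle more cases, such as lists:
import java.util.LinkedHashMap;
import java.util.Map;

@SuppressWarnings("unchecked")
static Map<String, Object> unflattenSketch(Map<String, Object> source) {
    Map<String, Object> result = new LinkedHashMap<>();
    for (Map.Entry<String, Object> entry : source.entrySet()) {
        String[] parts = entry.getKey().split("\\.");
        Map<String, Object> current = result;
        for (int i = 0; i < parts.length - 1; i++) {
            // create an intermediate map for every dotted segment
            current = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new LinkedHashMap<String, Object>());
        }
        current.put(parts[parts.length - 1], entry.getValue());
    }
    return result;
}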
- Map favoriteGames = (Map) favorites.get("games"); - assertNotNull(favoriteGames); - Map adventure = (Map) favoriteGames.get("adventure"); - List adventureKnnVector = (List) adventure.get("with_action_knn"); - assertNotNull(adventureKnnVector); - assertEquals(100, adventureKnnVector.size()); - for (float vector : adventureKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - - List favoriteKnnVector = (List) favorites.get("favorite_movie_knn"); - assertNotNull(favoriteKnnVector); - assertEquals(100, favoriteKnnVector.size()); - for (float vector : favoriteKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - } - - @SneakyThrows - @SuppressWarnings("unchecked") - public void testBuildVectorOutput_withFlattenedNestedMapAndList_successful() { - Map config = createNestedMapConfiguration(); - IngestDocument ingestDocument = createFlattenedNestedMapAndListIngestDocument(); - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - processor.preprocessIngestDocument(ingestDocument); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(3, 100, 0.0f, 1.0f); - processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); - /** - * "favorites.favorite": { - * "movie": "matrix", - * "actor": "Charlie Chaplin", - * "games" : [ - * { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * }, - * { - * "adventure.action": "wukong" - * } - * ] - * } - */ - Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - assertNotNull(favoritesMap); - Map favorite = (Map) favoritesMap.get("favorite"); - assertNotNull(favorite); - - List> favoriteGames = (List>) favorite.get("games"); - assertNotNull(favoriteGames); - for (Map favoriteGame : favoriteGames) { - Map adventure = (Map) favoriteGame.get("adventure"); - List adventureKnnVector = (List) adventure.get("with_action_knn"); - assertNotNull(adventureKnnVector); - assertEquals(100, adventureKnnVector.size()); - for (float vector : adventureKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - } - - List favoriteKnnVector = (List) favorite.get("favorite_movie_knn"); - assertNotNull(favoriteKnnVector); - assertEquals(100, favoriteKnnVector.size()); - for (float vector : favoriteKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - } - public void testBuildVectorOutput_withNestedList_successful() { Map config = createNestedListConfiguration(); IngestDocument ingestDocument = createNestedListIngestDocument(); @@ -956,8 +891,8 @@ public void testBuildVectorOutput_withNestedListLevel2_withPartialNullNestedFiel * } */ List> nestedList = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); - Map objWithNullText = buildObjMap(Pair.of("textField", null)); - Map nestedObjWithNullText = buildObjMap(Pair.of("nestedField", objWithNullText)); + Map objWithNullText = buildObjMapWithSingleField("textField", null); + Map nestedObjWithNullText = buildObjMapWithSingleField("nestedField", objWithNullText); nestedList.set(0, nestedObjWithNullText); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); @@ -995,79 +930,6 @@ public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_s assertNotNull(nestedObj.get(1).get("vectorField")); } - @SuppressWarnings("unchecked") - public void 
testBuildVectorOutput_withPlainString_EmptyString_skipped() { - Map config = createPlainStringConfiguration(); - IngestDocument ingestDocument = createPlainIngestDocument(); - Map sourceAndMetadata = ingestDocument.getSourceAndMetadata(); - sourceAndMetadata.put("oriKey1", StringUtils.EMPTY); - - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(6, 100, 0.0f, 1.0f); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); - - /** IngestDocument - * "oriKey1": "", - * "oriKey2": "oriValue2", - * "oriKey3": "oriValue3", - * "oriKey4": "oriValue4", - * "oriKey5": "oriValue5", - * "oriKey6": [ - * "oriValue6", - * "oriValue7" - * ] - * - */ - assertEquals(11, sourceAndMetadata.size()); - assertFalse(sourceAndMetadata.containsKey("oriKey1_knn")); - } - - @SuppressWarnings("unchecked") - public void testBuildVectorOutput_withNestedField_EmptyString_skipped() { - Map config = createNestedMapConfiguration(); - IngestDocument ingestDocument = createNestedMapIngestDocument(); - Map favorites = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - Map favorite = (Map) favorites.get("favorite"); - favorite.put("movie", StringUtils.EMPTY); - - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); - processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); - - /** - * "favorites": { - * "favorite": { - * "movie": "", - * "actor": "Charlie Chaplin", - * "games" : { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * } - * } - * } - */ - Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - assertNotNull(favoritesMap); - Map favoriteMap = (Map) favoritesMap.get("favorite"); - assertNotNull(favoriteMap); - - Map favoriteGames = (Map) favoriteMap.get("games"); - assertNotNull(favoriteGames); - Map adventure = (Map) favoriteGames.get("adventure"); - List adventureKnnVector = (List) adventure.get("with_action_knn"); - assertNotNull(adventureKnnVector); - assertEquals(100, adventureKnnVector.size()); - for (float vector : adventureKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - - assertFalse(favoriteMap.containsKey("favorite_movie_knn")); - } - public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); @@ -1245,22 +1107,21 @@ private void assertMapWithNestedFields(Pair actual, List @SneakyThrows private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map fieldMap) { Map registry = new HashMap<>(); - Map config = buildObjMap( - Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), - Pair.of(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap) - ); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap); return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } private Map createPlainStringConfiguration() { - return buildObjMap( - Pair.of("oriKey1", "oriKey1_knn"), - Pair.of("oriKey2", "oriKey2_knn"), - Pair.of("oriKey3", 
"oriKey3_knn"), - Pair.of("oriKey4", "oriKey4_knn"), - Pair.of("oriKey5", "oriKey5_knn"), - Pair.of("oriKey6", "oriKey6_knn") - ); + Map config = new HashMap<>(); + config.put("oriKey1", "oriKey1_knn"); + config.put("oriKey2", "oriKey2_knn"); + config.put("oriKey3", "oriKey3_knn"); + config.put("oriKey4", "oriKey4_knn"); + config.put("oriKey5", "oriKey5_knn"); + config.put("oriKey6", "oriKey6_knn"); + return config; } /** @@ -1273,24 +1134,24 @@ private Map createPlainStringConfiguration() { * } */ private Map createNestedMapConfiguration() { - Map adventureGames = buildObjMap(Pair.of("adventure.action", "with_action_knn")); - Map favorite = buildObjMap( - Pair.of("favorite.movie", "favorite_movie_knn"), - Pair.of("favorite.games", adventureGames) - ); - Map result = buildObjMap(Pair.of("favorites", favorite)); + Map adventureGames = new HashMap<>(); + adventureGames.put("adventure.action", "with_action_knn"); + Map favorite = new HashMap<>(); + favorite.put("favorite.movie", "favorite_movie_knn"); + favorite.put("favorite.games", adventureGames); + Map result = new HashMap<>(); + result.put("favorites", favorite); return result; } private IngestDocument createPlainIngestDocument() { - Map result = buildObjMap( - Pair.of("oriKey1", "oriValue1"), - Pair.of("oriKey2", "oriValue2"), - Pair.of("oriKey3", "oriValue3"), - Pair.of("oriKey4", "oriValue4"), - Pair.of("oriKey5", "oriValue5"), - Pair.of("oriKey6", ImmutableList.of("oriValue6", "oriValue7")) - ); + Map result = new HashMap<>(); + result.put("oriKey1", "oriValue1"); + result.put("oriKey2", "oriValue2"); + result.put("oriKey3", "oriValue3"); + result.put("oriKey4", "oriValue4"); + result.put("oriKey5", "oriValue5"); + result.put("oriKey6", ImmutableList.of("oriValue6", "oriValue7")); return new IngestDocument(result, new HashMap<>()); } @@ -1310,131 +1171,81 @@ private IngestDocument createPlainIngestDocument() { * } */ private IngestDocument createNestedMapIngestDocument() { - Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); - Map favGames = buildObjMap(Pair.of("adventure", adventureGames)); - Map favorites = buildObjMap( - Pair.of("movie", "matrix"), - Pair.of("games", favGames), - Pair.of("actor", "Charlie Chaplin") - ); - Map favorite = buildObjMap(Pair.of("favorite", favorites)); - Map result = buildObjMap(Pair.of("favorites", favorite)); - return new IngestDocument(result, new HashMap<>()); - } - - /** - * Create following document with flattened nested map - * "favorites.favorite": { - * "movie": "matrix", - * "actor": "Charlie Chaplin", - * "games" : { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * } - * } - */ - private IngestDocument createFlattenedNestedMapIngestDocument() { - Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); - Map favGames = buildObjMap(Pair.of("adventure", adventureGames)); - Map favorites = buildObjMap( - Pair.of("movie", "matrix"), - Pair.of("games", favGames), - Pair.of("actor", "Charlie Chaplin") - ); - Map result = buildObjMap(Pair.of("favorites.favorite", favorites)); - return new IngestDocument(result, new HashMap<>()); - } - - /** - * Create following document with flattened nested map and list - * "favorites.favorite": { - * "movie": "matrix", - * "actor": "Charlie Chaplin", - * "games" : [ - * { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * }, - * { - * "adventure.action": "wukong" - * } - * ] - * } - */ - private IngestDocument 
createFlattenedNestedMapAndListIngestDocument() { - Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); - Map game1 = buildObjMap(Pair.of("adventure", adventureGames)); - Map game2 = buildObjMap(Pair.of("adventure.action", "wukong")); - Map favorites = buildObjMap( - Pair.of("movie", "matrix"), - Pair.of("games", Arrays.asList(game1, game2)), - Pair.of("actor", "Charlie Chaplin") - ); - Map result = buildObjMap(Pair.of("favorites.favorite", favorites)); + Map adventureGames = new HashMap<>(); + adventureGames.put("action", "overwatch"); + adventureGames.put("rpg", "elden ring"); + Map favGames = new HashMap<>(); + favGames.put("adventure", adventureGames); + Map favorites = new HashMap<>(); + favorites.put("movie", "matrix"); + favorites.put("games", favGames); + favorites.put("actor", "Charlie Chaplin"); + Map favorite = new HashMap<>(); + favorite.put("favorite", favorites); + Map result = new HashMap<>(); + result.put("favorites", favorite); return new IngestDocument(result, new HashMap<>()); } private Map createNestedListConfiguration() { - Map nestedConfig = buildObjMap(Pair.of("textField", "vectorField")); - return buildObjMap(Pair.of("nestedField", nestedConfig)); + Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); + return buildObjMapWithSingleField("nestedField", nestedConfig); } private Map createNestedList2LevelConfiguration() { - Map nestedConfig = buildObjMap(Pair.of("textField", "vectorField")); - Map nestConfigLevel1 = buildObjMap(Pair.of("nestedField", nestedConfig)); - return buildObjMap(Pair.of("nestedField", nestConfigLevel1)); + Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); + Map nestConfigLevel1 = buildObjMapWithSingleField("nestedField", nestedConfig); + return buildObjMapWithSingleField("nestedField", nestConfigLevel1); } private IngestDocument createNestedListIngestDocument() { - Map nestedObj1 = buildObjMap(Pair.of("textField", "This is a text field")); - Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); - Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); + Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); return new IngestDocument(nestedList, new HashMap<>()); } private IngestDocument createNestedListWithNotEmbeddingFieldIngestDocument() { - Map nestedObj1 = buildObjMap(Pair.of("textFieldNotForEmbedding", "This is a text field")); - Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); - Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); + Map nestedObj1 = buildObjMapWithSingleField("textFieldNotForEmbedding", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); return new IngestDocument(nestedList, new HashMap<>()); } private IngestDocument create2LevelNestedListIngestDocument() { - Map nestedObj1 = buildObjMap(Pair.of("textField", "This is a text field")); - Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); - Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); - Map 
nestedList1 = buildObjMap(Pair.of("nestedField", nestedList)); + Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedList1 = buildObjMapWithSingleField("nestedField", nestedList); return new IngestDocument(nestedList1, new HashMap<>()); } private IngestDocument create2LevelNestedListWithNestedFieldsIngestDocument() { - Map nestedObj1Level2 = buildObjMap(Pair.of("textField", "This is a text field")); - Map nestedObj1Level1 = buildObjMap(Pair.of("nestedField", nestedObj1Level2)); + Map nestedObj1Level2 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj1Level1 = buildObjMapWithSingleField("nestedField", nestedObj1Level2); - Map nestedObj2Level2 = buildObjMap(Pair.of("textField", "This is another text field")); - Map nestedObj2Level1 = buildObjMap(Pair.of("nestedField", nestedObj2Level2)); + Map nestedObj2Level2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedObj2Level1 = buildObjMapWithSingleField("nestedField", nestedObj2Level2); - Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1))); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1)); return new IngestDocument(nestedList, new HashMap<>()); } - private Map buildObjMap(Pair... pairs) { + private Map buildObjMapWithSingleField(String fieldName, Object fieldValue) { Map objMap = new HashMap<>(); - for (Pair pair : pairs) { - objMap.put(pair.getKey(), pair.getValue()); - } + objMap.put(fieldName, fieldValue); return objMap; } private IngestDocument create2LevelNestedListWithNotEmbeddingFieldIngestDocument() { - Map nestedObj1 = buildObjMap(Pair.of("textFieldNotForEmbedding", "This is a text field")); - Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); - Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); - Map nestedList1 = buildObjMap(Pair.of("nestedField", nestedList)); + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textFieldNotForEmbedding", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + HashMap nestedList1 = new HashMap<>(); + nestedList1.put("nestedField", nestedList); return new IngestDocument(nestedList1, new HashMap<>()); } } From 503189c1c4d9fc5982816d0a02d707e1ea6b3522 Mon Sep 17 00:00:00 2001 From: Will Hwang <22586574+will-hwang@users.noreply.github.com> Date: Mon, 6 Jan 2025 18:49:34 -0800 Subject: [PATCH 09/16] add support for builder constructor in neural query builder (#1047) * add support for builder constructor in neural query builder Signed-off-by: will-hwang * create custom builder class to enforce valid neural query builder instantiation Signed-off-by: will-hwang * refactor code to remove duplicate Signed-off-by: will-hwang * include new constructor in qa packages Signed-off-by: will-hwang * refactor code to remove unnecessary code Signed-off-by: will-hwang * fix bug in neural query builder instantiation Signed-off-by: will-hwang --------- Signed-off-by: will-hwang --- .../org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java | 1 + 
.../org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java | 1 + .../neuralsearch/processor/NormalizationProcessorIT.java | 1 + .../opensearch/neuralsearch/processor/ScoreCombinationIT.java | 3 +++ .../opensearch/neuralsearch/query/HybridQueryBuilderTests.java | 1 + .../opensearch/neuralsearch/query/NeuralQueryBuilderTests.java | 1 + 6 files changed, 8 insertions(+) diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index 52d2ee173..88af8b757 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -95,6 +95,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo .modelId(modelId) .maxDistance(100000f) .build(); + Map responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); assertNotNull(responseWithMaxScore); } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index 4dc33a15b..e2df88d6d 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -83,6 +83,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod .modelId(modelId) .k(1) .build(); + Map responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1); assertNotNull(responseWithKQuery); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 72b74b28c..3c5fc08ef 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -139,6 +139,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu .modelId(modelId) .k(5) .build(); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index 88dbfa735..feb914e30 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -250,6 +250,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -300,6 +301,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( 
NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() + ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -325,6 +327,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( NeuralQueryBuilder.builder().fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).queryText(TEST_DOC_TEXT1).modelId(modelId).k(5).build() + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index c22d174ec..2385300a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -569,6 +569,7 @@ public void testStreams_whenWrittingToStream_thenSuccessful() { .queryText(QUERY_TEXT) .modelId(MODEL_ID) .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) .build(); original.add(neuralQueryBuilder); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 5c146be1c..edb9c8fd4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -572,6 +572,7 @@ private void testStreams() { NeuralQueryBuilder original = NeuralQueryBuilder.builder() .fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .boost(BOOST) From 1315ab2e29ea7f351019273045e89aa9e5186a46 Mon Sep 17 00:00:00 2001 From: Will Hwang <22586574+will-hwang@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:58:57 -0800 Subject: [PATCH 10/16] add hybrid search with rescore IT (#1066) * add hybrid search with rescore IT Signed-off-by: will-hwang * remove rescore in hybrid search IT Signed-off-by: will-hwang * remove previous version checks in build file Signed-off-by: will-hwang * removing version checks only in rolling upgrade tests Signed-off-by: will-hwang * remove newly added tests in restart test Signed-off-by: will-hwang * Revert "remove newly added tests in restart test" This reverts commit 09878318bd74320ea06b03c3b65251fbbca6ad4f. 
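For context, this is roughly the rescorer that the earlier revision attached and this change drops again; the matchQuery/boost pair mirrors the removed lines below, while wiring it through SearchSourceBuilder and QueryRescorerBuilder is an assumption about how such a rescorer would be attached, not code from this IT:

    QueryBuilder rescoreQuery = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
    // hypothetical wiring; the IT below simply passes a QueryBuilder (or null) to its helper
    SearchSourceBuilder source = new SearchSourceBuilder().query(hybridQueryBuilder)
        .addRescorer(new QueryRescorerBuilder(rescoreQuery).windowSize(10));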
Signed-off-by: will-hwang --------- Signed-off-by: will-hwang --- .../org/opensearch/neuralsearch/bwc/HybridSearchIT.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index f64ddd455..95924112c 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -68,8 +68,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); - QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null); addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; @@ -84,10 +83,9 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); - QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault()); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); } From 597d2b481673357f526b5ca9a3d9059c717482f6 Mon Sep 17 00:00:00 2001 From: Yizhe Liu <59710443+yizheliu-amazon@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:14:24 -0800 Subject: [PATCH 11/16] Fix bug where document embedding fails to be generated due to document has dot in field name (#1062) * Fix bug where document embedding fails to be generated due to document has dot in field name Signed-off-by: Yizhe Liu * Address comments Signed-off-by: Yizhe Liu --------- Signed-off-by: Yizhe Liu --- .../neuralsearch/bwc/HybridSearchIT.java | 1 - .../processor/InferenceProcessor.java | 41 ++- .../util/ProcessorDocumentUtils.java | 24 ++ .../TextEmbeddingProcessorTests.java | 330 +++++++++++++----- 4 files changed, 300 insertions(+), 96 deletions(-) diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 95924112c..44671ed4a 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -20,7 +20,6 @@ import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import 
org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index ff1b663f8..6ee54afe7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -137,6 +137,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { + preprocessIngestDocument(ingestDocument); validateEmbeddingFieldsValue(ingestDocument); Map processMap = buildMapWithTargetKeys(ingestDocument); List inferenceList = createInferenceList(processMap); @@ -150,6 +151,15 @@ public void execute(IngestDocument ingestDocument, BiConsumer sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap); + unflattened.forEach(ingestDocument::setFieldValue); + sourceAndMetadataMap.keySet().removeIf(key -> key.contains(".")); + } + /** * This is the function which does actual inference work for batchExecute interface. * @param inferenceList a list of String for inference. @@ -244,12 +254,14 @@ private List getDataForInference(List i for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) { Map processMap = null; List inferenceList = null; + IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument(); try { - validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument()); - processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument()); + preprocessIngestDocument(ingestDocument); + validateEmbeddingFieldsValue(ingestDocument); + processMap = buildMapWithTargetKeys(ingestDocument); inferenceList = createInferenceList(processMap); } catch (Exception e) { - ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e); + ingestDocumentWrapper.update(ingestDocument, e); } finally { dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList)); } @@ -333,10 +345,14 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { List> list = (List>) sourceAndMetadataMap.get(parentKey); - List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); + Pair processedNestedKey = processNestedKey(nestedFieldMapEntry); + List listOfStrings = list.stream().map(x -> { + Object nestedSourceValue = x.get(processedNestedKey.getKey()); + return normalizeSourceValue(nestedSourceValue); + }).collect(Collectors.toList()); Map map = new LinkedHashMap<>(); - map.put(nestedFieldMapEntry.getKey(), listOfStrings); - buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next); + map.put(processedNestedKey.getKey(), listOfStrings); + buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next); } } treeRes.merge(parentKey, next, REMAPPING_FUNCTION); @@ -346,6 +362,17 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } } + private boolean isBlankString(Object object) { + return object instanceof String && 
StringUtils.isBlank((String) object); + } + + private Object normalizeSourceValue(Object value) { + if (isBlankString(value)) { + return null; + } + return value; + } + /** * Process the nested key, such as "a.b.c" to "a", "b.c" * @param nestedFieldMapEntry @@ -372,7 +399,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { ProcessorDocumentUtils.validateMapTypeValue( FIELD_MAP_FIELD, sourceAndMetadataMap, - fieldMap, + ProcessorDocumentUtils.unflattenJson(fieldMap), indexName, clusterService, environment, diff --git a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java index 0cbf4534d..6f9297e5c 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java @@ -296,6 +296,30 @@ private static void unflattenSingleItem(String key, Object value, Map child1Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); - Map child1Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child1Level2); - Map child2Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); - child2Level2.put(CHILD_2_TEXT_FIELD, TEXT_VALUE_2); - child2Level2.put(CHILD_3_TEXT_FIELD, TEXT_VALUE_3); - Map child2Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child2Level2); - Map sourceAndMetadata = Map.of( - PARENT_FIELD, - Arrays.asList(child1Level1, child2Level1), - IndexFieldMapper.NAME, - "my_index" + Map child1Level2 = buildObjMap(Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1)); + Map child1Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child1Level2)); + Map child2Level2 = buildObjMap( + Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1), + Pair.of(CHILD_2_TEXT_FIELD, TEXT_VALUE_2), + Pair.of(CHILD_3_TEXT_FIELD, TEXT_VALUE_3) + ); + Map child2Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child2Level2)); + Map sourceAndMetadata = buildObjMap( + Pair.of(PARENT_FIELD, Arrays.asList(child1Level1, child2Level1)), + Pair.of(IndexFieldMapper.NAME, "my_index") ); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - Map.of( - PARENT_FIELD, - Map.of(CHILD_FIELD_LEVEL_1, Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD))) + Map config = buildObjMap( + Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), + Pair.of( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + buildObjMap( + Pair.of( + PARENT_FIELD, + Map.of( + CHILD_FIELD_LEVEL_1, + Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD)) + ) + ) + ) ) ); TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( @@ -808,6 +814,103 @@ public void testBuildVectorOutput_withNestedMap_successful() { } } + @SneakyThrows + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withFlattenedNestedMap_successful() { + Map config = createNestedMapConfiguration(); + IngestDocument ingestDocument = createFlattenedNestedMapIngestDocument(); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + processor.preprocessIngestDocument(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = 
createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f); + processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + /** + * "favorites.favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * } + * } + */ + Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); + assertNotNull(favoritesMap); + Map favorites = (Map) favoritesMap.get("favorite"); + assertNotNull(favorites); + + Map favoriteGames = (Map) favorites.get("games"); + assertNotNull(favoriteGames); + Map adventure = (Map) favoriteGames.get("adventure"); + List adventureKnnVector = (List) adventure.get("with_action_knn"); + assertNotNull(adventureKnnVector); + assertEquals(100, adventureKnnVector.size()); + for (float vector : adventureKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + + List favoriteKnnVector = (List) favorites.get("favorite_movie_knn"); + assertNotNull(favoriteKnnVector); + assertEquals(100, favoriteKnnVector.size()); + for (float vector : favoriteKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withFlattenedNestedMapAndList_successful() { + Map config = createNestedMapConfiguration(); + IngestDocument ingestDocument = createFlattenedNestedMapAndListIngestDocument(); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + processor.preprocessIngestDocument(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(3, 100, 0.0f, 1.0f); + processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + /** + * "favorites.favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : [ + * { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * }, + * { + * "adventure.action": "wukong" + * } + * ] + * } + */ + Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); + assertNotNull(favoritesMap); + Map favorite = (Map) favoritesMap.get("favorite"); + assertNotNull(favorite); + + List> favoriteGames = (List>) favorite.get("games"); + assertNotNull(favoriteGames); + for (Map favoriteGame : favoriteGames) { + Map adventure = (Map) favoriteGame.get("adventure"); + List adventureKnnVector = (List) adventure.get("with_action_knn"); + assertNotNull(adventureKnnVector); + assertEquals(100, adventureKnnVector.size()); + for (float vector : adventureKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + + List favoriteKnnVector = (List) favorite.get("favorite_movie_knn"); + assertNotNull(favoriteKnnVector); + assertEquals(100, favoriteKnnVector.size()); + for (float vector : favoriteKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + public void testBuildVectorOutput_withNestedList_successful() { Map config = createNestedListConfiguration(); IngestDocument ingestDocument = createNestedListIngestDocument(); @@ -891,8 +994,8 @@ public void testBuildVectorOutput_withNestedListLevel2_withPartialNullNestedFiel * } */ List> nestedList = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); - Map objWithNullText = buildObjMapWithSingleField("textField", null); - Map nestedObjWithNullText = buildObjMapWithSingleField("nestedField", objWithNullText); + Map objWithNullText = 
buildObjMap(Pair.of("textField", null)); + Map nestedObjWithNullText = buildObjMap(Pair.of("nestedField", objWithNullText)); nestedList.set(0, nestedObjWithNullText); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); @@ -1107,21 +1210,22 @@ private void assertMapWithNestedFields(Pair actual, List @SneakyThrows private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map fieldMap) { Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap); + Map config = buildObjMap( + Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), + Pair.of(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap) + ); return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } private Map createPlainStringConfiguration() { - Map config = new HashMap<>(); - config.put("oriKey1", "oriKey1_knn"); - config.put("oriKey2", "oriKey2_knn"); - config.put("oriKey3", "oriKey3_knn"); - config.put("oriKey4", "oriKey4_knn"); - config.put("oriKey5", "oriKey5_knn"); - config.put("oriKey6", "oriKey6_knn"); - return config; + return buildObjMap( + Pair.of("oriKey1", "oriKey1_knn"), + Pair.of("oriKey2", "oriKey2_knn"), + Pair.of("oriKey3", "oriKey3_knn"), + Pair.of("oriKey4", "oriKey4_knn"), + Pair.of("oriKey5", "oriKey5_knn"), + Pair.of("oriKey6", "oriKey6_knn") + ); } /** @@ -1134,24 +1238,24 @@ private Map createPlainStringConfiguration() { * } */ private Map createNestedMapConfiguration() { - Map adventureGames = new HashMap<>(); - adventureGames.put("adventure.action", "with_action_knn"); - Map favorite = new HashMap<>(); - favorite.put("favorite.movie", "favorite_movie_knn"); - favorite.put("favorite.games", adventureGames); - Map result = new HashMap<>(); - result.put("favorites", favorite); + Map adventureGames = buildObjMap(Pair.of("adventure.action", "with_action_knn")); + Map favorite = buildObjMap( + Pair.of("favorite.movie", "favorite_movie_knn"), + Pair.of("favorite.games", adventureGames) + ); + Map result = buildObjMap(Pair.of("favorites", favorite)); return result; } private IngestDocument createPlainIngestDocument() { - Map result = new HashMap<>(); - result.put("oriKey1", "oriValue1"); - result.put("oriKey2", "oriValue2"); - result.put("oriKey3", "oriValue3"); - result.put("oriKey4", "oriValue4"); - result.put("oriKey5", "oriValue5"); - result.put("oriKey6", ImmutableList.of("oriValue6", "oriValue7")); + Map result = buildObjMap( + Pair.of("oriKey1", "oriValue1"), + Pair.of("oriKey2", "oriValue2"), + Pair.of("oriKey3", "oriValue3"), + Pair.of("oriKey4", "oriValue4"), + Pair.of("oriKey5", "oriValue5"), + Pair.of("oriKey6", ImmutableList.of("oriValue6", "oriValue7")) + ); return new IngestDocument(result, new HashMap<>()); } @@ -1171,81 +1275,131 @@ private IngestDocument createPlainIngestDocument() { * } */ private IngestDocument createNestedMapIngestDocument() { - Map adventureGames = new HashMap<>(); - adventureGames.put("action", "overwatch"); - adventureGames.put("rpg", "elden ring"); - Map favGames = new HashMap<>(); - favGames.put("adventure", adventureGames); - Map favorites = new HashMap<>(); - favorites.put("movie", "matrix"); - favorites.put("games", favGames); - favorites.put("actor", "Charlie Chaplin"); - Map favorite = new HashMap<>(); - favorite.put("favorite", favorites); 
- Map result = new HashMap<>(); - result.put("favorites", favorite); + Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); + Map favGames = buildObjMap(Pair.of("adventure", adventureGames)); + Map favorites = buildObjMap( + Pair.of("movie", "matrix"), + Pair.of("games", favGames), + Pair.of("actor", "Charlie Chaplin") + ); + Map favorite = buildObjMap(Pair.of("favorite", favorites)); + Map result = buildObjMap(Pair.of("favorites", favorite)); + return new IngestDocument(result, new HashMap<>()); + } + + /** + * Create following document with flattened nested map + * "favorites.favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * } + * } + */ + private IngestDocument createFlattenedNestedMapIngestDocument() { + Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); + Map favGames = buildObjMap(Pair.of("adventure", adventureGames)); + Map favorites = buildObjMap( + Pair.of("movie", "matrix"), + Pair.of("games", favGames), + Pair.of("actor", "Charlie Chaplin") + ); + Map result = buildObjMap(Pair.of("favorites.favorite", favorites)); + return new IngestDocument(result, new HashMap<>()); + } + + /** + * Create following document with flattened nested map and list + * "favorites.favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : [ + * { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * }, + * { + * "adventure.action": "wukong" + * } + * ] + * } + */ + private IngestDocument createFlattenedNestedMapAndListIngestDocument() { + Map adventureGames = buildObjMap(Pair.of("action", "overwatch"), Pair.of("rpg", "elden ring")); + Map game1 = buildObjMap(Pair.of("adventure", adventureGames)); + Map game2 = buildObjMap(Pair.of("adventure.action", "wukong")); + Map favorites = buildObjMap( + Pair.of("movie", "matrix"), + Pair.of("games", Arrays.asList(game1, game2)), + Pair.of("actor", "Charlie Chaplin") + ); + Map result = buildObjMap(Pair.of("favorites.favorite", favorites)); return new IngestDocument(result, new HashMap<>()); } private Map createNestedListConfiguration() { - Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); - return buildObjMapWithSingleField("nestedField", nestedConfig); + Map nestedConfig = buildObjMap(Pair.of("textField", "vectorField")); + return buildObjMap(Pair.of("nestedField", nestedConfig)); } private Map createNestedList2LevelConfiguration() { - Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); - Map nestConfigLevel1 = buildObjMapWithSingleField("nestedField", nestedConfig); - return buildObjMapWithSingleField("nestedField", nestConfigLevel1); + Map nestedConfig = buildObjMap(Pair.of("textField", "vectorField")); + Map nestConfigLevel1 = buildObjMap(Pair.of("nestedField", nestedConfig)); + return buildObjMap(Pair.of("nestedField", nestConfigLevel1)); } private IngestDocument createNestedListIngestDocument() { - Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); - Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); - Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedObj1 = buildObjMap(Pair.of("textField", "This is a text field")); + Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); + Map nestedList = 
buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); return new IngestDocument(nestedList, new HashMap<>()); } private IngestDocument createNestedListWithNotEmbeddingFieldIngestDocument() { - Map nestedObj1 = buildObjMapWithSingleField("textFieldNotForEmbedding", "This is a text field"); - Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); - Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedObj1 = buildObjMap(Pair.of("textFieldNotForEmbedding", "This is a text field")); + Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); + Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); return new IngestDocument(nestedList, new HashMap<>()); } private IngestDocument create2LevelNestedListIngestDocument() { - Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); - Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); - Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); - Map nestedList1 = buildObjMapWithSingleField("nestedField", nestedList); + Map nestedObj1 = buildObjMap(Pair.of("textField", "This is a text field")); + Map nestedObj2 = buildObjMap(Pair.of("textField", "This is another text field")); + Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1, nestedObj2))); + Map nestedList1 = buildObjMap(Pair.of("nestedField", nestedList)); return new IngestDocument(nestedList1, new HashMap<>()); } private IngestDocument create2LevelNestedListWithNestedFieldsIngestDocument() { - Map nestedObj1Level2 = buildObjMapWithSingleField("textField", "This is a text field"); - Map nestedObj1Level1 = buildObjMapWithSingleField("nestedField", nestedObj1Level2); + Map nestedObj1Level2 = buildObjMap(Pair.of("textField", "This is a text field")); + Map nestedObj1Level1 = buildObjMap(Pair.of("nestedField", nestedObj1Level2)); - Map nestedObj2Level2 = buildObjMapWithSingleField("textField", "This is another text field"); - Map nestedObj2Level1 = buildObjMapWithSingleField("nestedField", nestedObj2Level2); + Map nestedObj2Level2 = buildObjMap(Pair.of("textField", "This is another text field")); + Map nestedObj2Level1 = buildObjMap(Pair.of("nestedField", nestedObj2Level2)); - Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1)); + Map nestedList = buildObjMap(Pair.of("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1))); return new IngestDocument(nestedList, new HashMap<>()); } - private Map buildObjMapWithSingleField(String fieldName, Object fieldValue) { + private Map buildObjMap(Pair... 
From 6be95ce589eb098621b73acf9077b707053c9e24 Mon Sep 17 00:00:00 2001
From: Yizhe Liu <59710443+yizheliu-amazon@users.noreply.github.com>
Date: Wed, 8 Jan 2025 09:58:54 -0800
Subject: [PATCH 12/16] Clean up unused validateFieldName() and use existing methods for TextEmbeddingProcessorIT (#1074)

Signed-off-by: Yizhe Liu
---
 .../util/ProcessorDocumentUtils.java | 24 -------------------
 1 file changed, 24 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java
index 6f9297e5c..0cbf4534d 100644
--- a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java
+++ b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java
@@ -296,30 +296,6 @@ private static void unflattenSingleItem(String key, Object value, Map<String, Object>

Date: Wed, 8 Jan 2025 16:11:19 -0800
Subject: [PATCH 13/16] Correct NeuralQueryBuilder doEquals() and doHashCode(). (#1045)
Signed-off-by: Bo Zhang
---
 .../org/opensearch/neuralsearch/query/NeuralQueryBuilder.java | 1 +
 .../opensearch/neuralsearch/query/HybridQueryBuilderTests.java | 1 -
 .../opensearch/neuralsearch/query/NeuralQueryBuilderTests.java | 1 -
 3 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java
index bdaea1567..42d56b85c 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java
@@ -302,6 +302,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
         if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
             out.writeOptionalBoolean(this.expandNested);
         }
+
         if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
             MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
         }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
index 2385300a6..c22d174ec 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
@@ -569,7 +569,6 @@ public void testStreams_whenWrittingToStream_thenSuccessful() {
             .queryText(QUERY_TEXT)
             .modelId(MODEL_ID)
             .k(K)
-            .vectorSupplier(TEST_VECTOR_SUPPLIER)
             .build();
         original.add(neuralQueryBuilder);
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
index edb9c8fd4..5c146be1c 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
@@ -572,7 +572,6 @@ private void testStreams() {
         NeuralQueryBuilder original = NeuralQueryBuilder.builder()
             .fieldName(FIELD_NAME)
             .queryText(QUERY_TEXT)
-            .queryImage(IMAGE_TEXT)
             .modelId(MODEL_ID)
             .k(K)
             .boost(BOOST)

From f5ae67a79c72356a51c540d80ddd8acef30f1cb0 Mon Sep 17 00:00:00 2001
From: Isaac Johnson <114550967+Johnsonisaacn@users.noreply.github.com>
Date: Fri, 18 Oct 2024 09:44:07 -0700
Subject: [PATCH 14/16] Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874)

* initial commit of RRF

Signed-off-by: Isaac Johnson
Co-authored-by: Varun Jain
Signed-off-by: Martin Gaievski
---
 .../processor/combination/RRFScoreCombinationTechnique.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java
index 6d6c94b94..0f43688a6 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java
@@ -13,10 +13,10 @@ import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;
 
-@Log4j2
 /**
  * Abstracts combination of scores based on reciprocal rank fusion algorithm
  */
+@Log4j2
 @ToString(onlyExplicitlyIncluded = true)
 public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
     @ToString.Include
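For readers following the RRF patches: the technique above operates on scores that the normalization step has already turned into reciprocal-rank contributions, so a document's fused score reduces to the sum of 1/(rank_constant + rank) over the subqueries that returned it. A self-contained sketch of that arithmetic, assuming the conventional rank constant of 60; the class, constant, and method names below are illustrative, not taken from these hunks:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RrfSketch {

    // Conventional RRF rank constant; assumed here, not read from the patch.
    static final int RANK_CONSTANT = 60;

    // Each inner list is one subquery's doc IDs, best hit first.
    static Map<String, Float> fuse(List<List<String>> rankedSubqueryResults) {
        Map<String, Float> fused = new HashMap<>();
        for (List<String> ranking : rankedSubqueryResults) {
            for (int i = 0; i < ranking.size(); i++) {
                int rank = i + 1; // RRF ranks are 1-based
                fused.merge(ranking.get(i), 1.0f / (RANK_CONSTANT + rank), Float::sum);
            }
        }
        return fused;
    }

    public static void main(String[] args) {
        // doc2 places near the top of both subquery lists, so it fuses ahead of
        // docs that are strong in only one list (doc2 = 1/61 + 1/62, doc4 = 1/61).
        Map<String, Float> scores = fuse(List.of(
            List.of("doc2", "doc1", "doc3"),  // e.g. the lexical subquery
            List.of("doc4", "doc2", "doc1")   // e.g. the k-NN subquery
        ));
        scores.forEach((doc, score) -> System.out.printf("%s -> %.5f%n", doc, score));
    }
}
```

Because only ranks enter the formula, RRF needs no min-max or L2 score normalization across subqueries, which is what lets the processor skip the usual score-scaling configuration.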
From 6e5596d6eb9ed471b0918a6162bc386ad59bfc9e Mon Sep 17 00:00:00 2001
From: Ryan Bogan
Date: Tue, 3 Dec 2024 08:59:26 -0800
Subject: [PATCH 15/16] Add integration and unit tests for missing RRF coverage (#997)

* Initial unit test implementation

Signed-off-by: Ryan Bogan

---------

Signed-off-by: Ryan Bogan
Signed-off-by: Martin Gaievski
---
 .../java/org/opensearch/neuralsearch/processor/RRFProcessor.java | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
index 100cf9fc6..eaf25e3f1 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
@@ -14,6 +14,7 @@ import com.google.common.annotations.VisibleForTesting;
 
 import lombok.Getter;
+import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.search.fetch.FetchSearchResult;

From 312c7f729ac03c70bbb7161d06545399125525e2 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Mon, 23 Dec 2024 08:53:43 -0800
Subject: [PATCH 16/16] Integrate explainability for hybrid query into RRF processor (#1037)

* Integrate explainability for hybrid query into RRF processor

Signed-off-by: Martin Gaievski
---
 .../org/opensearch/neuralsearch/processor/RRFProcessor.java | 2 +-
 .../java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
index eaf25e3f1..cf9e3b820 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
@@ -14,7 +14,6 @@ import com.google.common.annotations.VisibleForTesting;
 
 import lombok.Getter;
-import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.search.fetch.FetchSearchResult;
@@ -80,6 +79,7 @@ void hybridizeScores(
             .combinationTechnique(combinationTechnique)
             .explain(explain)
             .pipelineProcessingContext(requestContextOptional.orElse(null))
+            .searchPhaseContext(searchPhaseContext)
             .build();
         normalizationWorkflow.execute(normalizationExecuteDTO);
     }
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index 95757e463..fb71fe0bb 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -1580,7 +1580,7 @@ protected void createRRFSearchPipeline(final String pipelineName, boolean addExp
         if (addExplainResponseProcessor) {
             builder.startArray("response_processors")
                 .startObject()
-                .startObject("explanation_response_processor")
+                .startObject("hybrid_score_explanation")
                .endObject()
                 .endObject()
                 .endArray();
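End to end, the series leaves createRRFSearchPipeline building a search pipeline with an RRF phase-results processor and, optionally, the renamed hybrid_score_explanation response processor. A rough sketch of the request body that fixture assembles: the score-ranker-processor type name and the rank_constant parameter are assumptions about what the factory registers, not shown in these hunks, while hybrid_score_explanation is taken from the last hunk above:

```java
public class RrfPipelineSketch {

    // Assumed names: "score-ranker-processor" and "rank_constant" do not appear
    // in these patches; "hybrid_score_explanation" does.
    static final String RRF_PIPELINE_BODY = """
        {
          "phase_results_processors": [
            { "score-ranker-processor": { "combination": { "technique": "rrf", "rank_constant": 60 } } }
          ],
          "response_processors": [
            { "hybrid_score_explanation": {} }
          ]
        }
        """;

    public static void main(String[] args) {
        // PUT /_search/pipeline/<name> with this body would wire both processors.
        System.out.println(RRF_PIPELINE_BODY);
    }
}
```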