diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index dba0926ff..42cdccb07 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -15,21 +15,17 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; -import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -37,8 +33,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. @@ -47,15 +45,11 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); - private static final String FLUSH_OPERATION = "flush"; - private static final String MERGE_OPERATION = "merge"; - private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; - private final QuantizationService quantizationService = QuantizationService.getInstance(); public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; @@ -84,14 +78,27 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - trainAndIndex( - field.getFieldInfo(), - (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), - NativeIndexWriter::flushIndex, - field, - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS, - FLUSH_OPERATION - ); + final FieldInfo fieldInfo = field.getFieldInfo(); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + int totalLiveDocs = field.getVectors().size(); + if (totalLiveDocs > 0) { + final Supplier> knnVectorValuesSupplier = () -> getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + + StopWatch stopWatch = new StopWatch().start(); + writer.flushIndex(knnVectorValues, totalLiveDocs); + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } else { + log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); + } } } @@ -100,15 +107,24 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs - trainAndIndex( - fieldInfo, - this::getKNNVectorValuesForMerge, - NativeIndexWriter::mergeIndex, - mergeState, - KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS, - MERGE_OPERATION - ); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + final Supplier> knnVectorValuesSupplier = () -> getVectorValues(vectorDataType, fieldInfo, mergeState); + final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + final int totalLiveDocs = Math.toIntExact(knnVectorValues.totalLiveDocs()); + if (totalLiveDocs <= 0) { + log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); + return; + } + + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + StopWatch stopWatch = new StopWatch().start(); + + writer.mergeIndex(knnVectorValues, totalLiveDocs); + + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } /** @@ -157,130 +173,22 @@ public long ramBytesUsed() { .sum(); } - /** - * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field. - */ - private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); - } - - /** - * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param mergeState The {@link MergeState} representing the state of the merge operation. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field during the merge. - * @throws IOException If an I/O error occurs during the retrieval. - */ - private KNNVectorValues getKNNVectorValuesForMerge( - final VectorDataType vectorDataType, - final FieldInfo fieldInfo, - final MergeState mergeState - ) throws IOException { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); - case BYTE: - ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } - } + private QuantizationState train(final FieldInfo fieldInfo, final Supplier> knnVectorValuesSupplier) + throws IOException { - /** - * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. - * - * @param The type of vectors being processed. - */ - @FunctionalInterface - private interface IndexOperation { - void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; - } - - /** - * Functional interface representing a method that retrieves {@link KNNVectorValues} based on - * the vector data type, field information, and the merge state. - * - * @param The type of the data representing the vector (e.g., {@link VectorDataType}). - * @param The metadata about the field. - * @param The state of the merge operation. - * @param The result of the retrieval, typically {@link KNNVectorValues}. - */ - @FunctionalInterface - private interface VectorValuesRetriever { - Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; - } - - /** - * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values - * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. - * - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, - * field information, and additional context (e.g., merge state or field writer). - * @param indexOperation A functional interface that performs the indexing operation using the retrieved - * {@link KNNVectorValues}. - * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). - * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information - * @param The type of vectors being processed. - * @param The type of the context needed for retrieving the vector values. - * @throws IOException If an I/O error occurs during the processing. - */ - private void trainAndIndex( - final FieldInfo fieldInfo, - final VectorValuesRetriever> vectorValuesRetriever, - final IndexOperation indexOperation, - final C VectorProcessingContext, - final KNNGraphValue graphBuildTime, - final String operationName - ) throws IOException { - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); - QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + final QuantizationService quantizationService = QuantizationService.getInstance(); + final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - // Count the docIds - int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext)); - if (quantizationParams != null && totalLiveDocs > 0) { - initQuantizationStateWriterIfNecessary(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); - quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); - } - NativeIndexWriter writer = (quantizationParams != null) - ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) - : NativeIndexWriter.getWriter(fieldInfo, segmentWriteState); - - knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); - - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs); - long time_in_millis = stopWatch.totalTime().millis(); - graphBuildTime.incrementBy(time_in_millis); - log.warn("Graph build took " + time_in_millis + " ms for " + operationName); - } - - /** - * The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the - * vectorsValues object which you plan to use later - */ - private int getLiveDocs(KNNVectorValues vectorValues) throws IOException { - // Count all the live docs as there vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues, - // and doesn't tell the correct number of docs, if there are deleted docs in the segment. So we are counting - // the total live docs here. - int liveDocs = 0; - while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - liveDocs++; + if (quantizationParams != null) { + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + long totalLiveDocs = knnVectorValues.totalLiveDocs(); + if (totalLiveDocs > 0) { + initQuantizationStateWriterIfNecessary(); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + } } - return liveDocs; + return quantizationState; } private void initQuantizationStateWriterIfNecessary() throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index 771848730..4c8e2c211 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -57,15 +57,15 @@ public static QuantizationService getInstance() { * @return The {@link QuantizationState} containing the state of the trained quantizer. * @throws IOException If an I/O error occurs during the training process. */ - public QuantizationState train( - final QuantizationParams quantizationParams, - final KNNVectorValues knnVectorValues, - final long liveDocs - ) throws IOException { + public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues) + throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); // Create the training request from the vector values - KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>( + knnVectorValues, + knnVectorValues.totalLiveDocs() + ); // Train the quantizer and return the quantization state return quantizer.train(trainingRequest); diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtil.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtil.java new file mode 100644 index 000000000..d5f6b2798 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtil.java @@ -0,0 +1,268 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.Bits; +import org.opensearch.common.StopWatch; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Utility class to get merged VectorValues from MergeState + */ +@Log4j2 +public final class KNNMergeVectorValuesUtil { + + /** + * Gets list of {@link KNNVectorValuesSub} for {@link FloatVectorValues} from a merge state and returns the iterator which + * iterates over live docs from all segments while mapping docIds. + * + * @param fieldInfo + * @param mergeState + * @return List of KNNVectorSub + * @throws IOException + */ + public static FloatVectorValues mergeFloatVectorValues(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + assert fieldInfo != null && fieldInfo.hasVectorValues(); + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32"); + } + final List> subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader != null) { + FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); + if (values != null) { + final Bits liveDocs = mergeState.liveDocs[i]; + StopWatch stopWatch = new StopWatch().start(); + // live docs cardinality is needed to make sure deletedDocs are not included in final count for merge + final int liveDocsInt; + if (liveDocs != null) { + liveDocsInt = cardinality(values, liveDocs); + values = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); + } else { + liveDocsInt = Math.toIntExact(values.cost()); + } + stopWatch.stop(); + log.debug("[FloatVectorValues] Time to compute live docs cardinality {} ms", stopWatch.totalTime().millis()); + subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt)); + } + } + } + return new MergeFloat32VectorValues(subs, mergeState); + } + + /** + * Gets list of {@link KNNVectorValuesSub} for {@link ByteVectorValues} from a merge state. This can be further + * used to create an iterator for getting the docs and its vector values + * @param fieldInfo + * @param mergeState + * @return List of KNNVectorSub + * @throws IOException + */ + public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + assert fieldInfo != null && fieldInfo.hasVectorValues(); + if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) { + throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE"); + } + final List> subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader != null) { + ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); + if (values != null) { + final Bits liveDocs = mergeState.liveDocs[i]; + StopWatch stopWatch = new StopWatch().start(); + // live docs cardinality is needed to make sure deletedDocs are not included in final count for merge + final int liveDocsInt; + if (liveDocs != null) { + liveDocsInt = cardinality(values, liveDocs); + values = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); + } else { + liveDocsInt = Math.toIntExact(values.cost()); + } + stopWatch.stop(); + log.debug("[ByteVectorValues] Time to compute live docs cardinality {} ms", stopWatch.totalTime().millis()); + subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt)); + } + } + } + return new MergeByteVectorValues(subs, mergeState); + } + + private static int cardinality(final DocIdSetIterator iterator, final Bits liveDocs) throws IOException { + int count = 0; + while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + if (liveDocs.get(iterator.docID())) { + count++; + } + } + return count; + } + + private static class KNNVectorValuesSub extends DocIDMerger.Sub { + final T values; + final int liveDocs; + + KNNVectorValuesSub(MergeState.DocMap docMap, T values, int liveDocs) { + super(docMap); + this.values = values; + this.liveDocs = liveDocs; + } + + @Override + public int nextDoc() throws IOException { + return values.nextDoc(); + } + } + + /** + * Iterator to get mapped docsIds from MergeState + */ + private static class MergeFloat32VectorValues extends FloatVectorValues { + + private final DocIDMerger> docIdMerger; + private final int liveDocs; + private int docId; + private final List> subs; + private KNNMergeVectorValuesUtil.KNNVectorValuesSub current; + + MergeFloat32VectorValues( + final List> subs, + final MergeState mergeState + ) throws IOException { + this.subs = subs; + this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (KNNMergeVectorValuesUtil.KNNVectorValuesSub sub : subs) { + totalSize += sub.liveDocs; + } + this.liveDocs = totalSize; + this.docId = -1; + } + + @Override + public int docID() { + return docId; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + } + return docId; + } + + @Override + public float[] vectorValue() throws IOException { + return current.values.vectorValue(); + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException("call to advance for MergeFloat32VectorValues not supported"); + } + + @Override + public int size() { + return liveDocs; + } + + @Override + public int dimension() { + return subs.get(0).values.dimension(); + } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException("call to scorer for MergeFloat32VectorValues is not supported"); + } + } + + /** + * Iterator to get mapped docsIds from MergeState + */ + @VisibleForTesting + private static class MergeByteVectorValues extends ByteVectorValues { + + private final DocIDMerger> docIdMerger; + private final int liveDocs; + private int docId; + private final List> subs; + private KNNMergeVectorValuesUtil.KNNVectorValuesSub current; + + MergeByteVectorValues(final List> subs, final MergeState mergeState) + throws IOException { + + this.subs = subs; + this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (KNNMergeVectorValuesUtil.KNNVectorValuesSub sub : subs) { + totalSize += sub.liveDocs; + } + this.liveDocs = totalSize; + this.docId = -1; + } + + @Override + public int docID() { + return docId; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + } + return docId; + } + + @Override + public byte[] vectorValue() throws IOException { + return current.values.vectorValue(); + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException("call to advance for MergeByteVectorValues not supported"); + } + + @Override + public int size() { + return liveDocs; + } + + @Override + public int dimension() { + return subs.get(0).values.dimension(); + } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException("call to scorer for MergeByteVectorValues is not supported"); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index b12395185..56ebd208f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -71,18 +71,9 @@ public int bytesPerVector() { } /** - * Returns the total live docs for KNNVectorValues. This function is broken and doesn't always give the accurate - * live docs count when iterators are {@link FloatVectorValues}, {@link ByteVectorValues}. Avoid using this iterator, - * rather use a simple function like this: - *
-     *     int liveDocs = 0;
-     *     while(vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
-     *         liveDocs++;
-     *     }
-     * 
+ * Returns the total live docs for KNNVectorValues. * @return long */ - @Deprecated public long totalLiveDocs() { return vectorValuesIterator.liveDocs(); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e217..159029a75 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,10 +5,12 @@ package org.opensearch.knn.index.vectorvalues; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.MergeState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; @@ -17,9 +19,13 @@ import java.io.IOException; import java.util.Map; +import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValuesUtil.mergeByteVectorValues; +import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValuesUtil.mergeFloatVectorValues; + /** * A factory class that provides various methods to create the {@link KNNVectorValues}. */ +@Log4j2 public final class KNNVectorValuesFactory { /** @@ -45,7 +51,37 @@ public static KNNVectorValues getVectorValues( final DocsWithFieldSet docIdWithFieldSet, final Map vectors ) { - return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues(docIdWithFieldSet, vectors)); + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<>(docIdWithFieldSet, vectors)); + } + + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo}, {@link VectorDataType} and {@link MergeState} + * Used for getting a common {@link KNNVectorValues} for all {@link org.apache.lucene.codecs.KnnVectorsReader} (segments) + * for a {@link MergeState} + * + * @param vectorDataType {@link VectorDataType} + * @param fieldInfo {@link FieldInfo} + * @param mergeState {@link MergeState} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) { + try { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + return getVectorValues(vectorDataType, mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE: + return getVectorValues(vectorDataType, mergeByteVectorValues(fieldInfo, mergeState)); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } catch (final IOException e) { + log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); + throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java index 4f1445c1c..c662be0af 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java @@ -184,5 +184,4 @@ public VectorValueExtractorStrategy getVectorExtractorStrategy() { return new VectorValueExtractorStrategy.FieldWriterIteratorVectorExtractor(); } } - } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f6..99dd91e32 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -122,5 +122,4 @@ public T extract(final VectorDataType vectorDataType, final KNNVectorValuesI ); } } - } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java new file mode 100644 index 000000000..934e97ea5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -0,0 +1,297 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final List> vectorsPerField; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Single field", List.of(Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }))), + $("Single field, no total live docs", List.of()), + $( + "Multi Field", + List.of( + Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Map.of( + 0, + new float[] { 1, 2, 3, 4 }, + 1, + new float[] { 2, 3, 4, 5 }, + 2, + new float[] { 3, 4, 5, 6 }, + 3, + new float[] { 4, 5, 6, 7 } + ) + ) + ) + ) + ); + } + + @SneakyThrows + public void testFlush() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size()) + ); + } + } + + @SneakyThrows + public void testFlush_WithQuantization() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i))).thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + assertTrue(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size() * 2) + ); + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java new file mode 100644 index 000000000..286513420 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -0,0 +1,238 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + @Mock + private FloatVectorValues floatVectorValues; + @Mock + private MergeState mergeState; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final Map mergedVectors; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Merge one field", Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 })), + $("Merge, no live docs", Map.of()) + ) + ); + } + + @SneakyThrows + public void testMerge() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + + @SneakyThrows + public void testMerge_WithQuantization() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, knnVectorValues)).thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + if (!mergedVectors.isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState), + times(2) + ); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 690391dbd..720b67fd5 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -46,7 +46,7 @@ public void setUp() throws Exception { public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +62,7 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +85,7 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -110,7 +110,7 @@ public void testTrain_fourBitQuantizer_success() throws IOException { public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +125,7 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +138,7 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -152,7 +152,7 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { public void testQuantize_whenInvalidInput_thenThrows() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); } diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtilTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtilTests.java new file mode 100644 index 000000000..6dbeecb40 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValuesUtilTests.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; +import org.mockito.Mock; +import org.opensearch.knn.KNNTestCase; + +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNNMergeVectorValuesUtilTests extends KNNTestCase { + + private static final String FIELD = "field"; + + @Mock + private KnnVectorsReader knnVectorsReader1; + @Mock + private KnnVectorsReader knnVectorsReader2; + @Mock + private KnnVectorsReader knnVectorsReader3; + @Mock + private BitSet fixedBitSetLiveDocs; + @Mock + private Bits fixedBitsLiveDocs; + + @SneakyThrows + public void testFloatMergeVectorValues() { + // Given + final KnnVectorsReader[] knnVectorsReaders = { knnVectorsReader1, knnVectorsReader2, knnVectorsReader3 }; + final Bits[] liveDocs = { fixedBitSetLiveDocs, fixedBitsLiveDocs, null }; + final Map floats1 = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }); + final List floats1List = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); + final Map floats2 = Map.of(0, new float[] { 3, 4 }, 1, new float[] { 4, 6 }); + final List floats2List = List.of(new float[] { 3, 4 }, new float[] { 4, 6 }); + final Map floats3 = Map.of(0, new float[] { 1, 2 }); + final List floats3List = List.of(new float[] { 1, 2 }); + final MergeState.DocMap[] docMaps = { + (docId) -> floats1.get(docId) != null ? docId : -1, + (docId) -> floats2.get(docId) != null ? docId + 2 : -1, + (docId) -> floats3.get(docId) != null ? docId + 4 : -1 }; + + Map floats = Map.of( + 0, + new float[] { 1, 2 }, + 1, + new float[] { 2, 3 }, + 2, + new float[] { 3, 4 }, + 3, + new float[] { 4, 6 }, + 4, + new float[] { 1, 2 } + ); + final TestVectorValues.PreDefinedFloatVectorValues vectorValues1 = new TestVectorValues.PreDefinedFloatVectorValues(floats1List); + final TestVectorValues.PreDefinedFloatVectorValues vectorValues2 = new TestVectorValues.PreDefinedFloatVectorValues(floats2List); + final TestVectorValues.PreDefinedFloatVectorValues vectorValues3 = new TestVectorValues.PreDefinedFloatVectorValues(floats3List); + + FieldInfo fieldInfo = fieldInfo(0, VectorEncoding.FLOAT32); + when(knnVectorsReader1.getFloatVectorValues(FIELD)).thenReturn(vectorValues1); + when(knnVectorsReader2.getFloatVectorValues(FIELD)).thenReturn(vectorValues2); + when(knnVectorsReader3.getFloatVectorValues(FIELD)).thenReturn(vectorValues3); + when(fixedBitsLiveDocs.get(anyInt())).thenReturn(true); + when(fixedBitSetLiveDocs.get(anyInt())).thenReturn(true); + + final MergeState mergeState = mergeState(knnVectorsReaders, liveDocs, docMaps); + + // When + FloatVectorValues iterator = KNNMergeVectorValuesUtil.mergeFloatVectorValues(fieldInfo, mergeState); + + // Then + assertEquals(5, iterator.size()); + int size = 0; + while (iterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(floats.get(iterator.docID()), iterator.vectorValue(), 0.01f); + size++; + } + assertEquals(floats.size(), size); + } + + @SneakyThrows + public void testByteMergeVectorValues() { + // Given + final KnnVectorsReader[] knnVectorsReaders = { knnVectorsReader1, knnVectorsReader2, knnVectorsReader3 }; + final Bits[] liveDocs = { fixedBitSetLiveDocs, fixedBitsLiveDocs, null }; + final Map floats1 = Map.of(0, new byte[] { 1, 2 }, 1, new byte[] { 2, 3 }); + final List floats1List = List.of(new byte[] { 1, 2 }, new byte[] { 2, 3 }); + final Map floats2 = Map.of(0, new byte[] { 3, 4 }, 1, new byte[] { 4, 6 }); + final List floats2List = List.of(new byte[] { 3, 4 }, new byte[] { 4, 6 }); + final Map floats3 = Map.of(0, new byte[] { 1, 2 }); + final List floats3List = List.of(new byte[] { 1, 2 }); + final MergeState.DocMap[] docMaps = { + (docId) -> floats1.get(docId) != null ? docId : -1, + (docId) -> floats2.get(docId) != null ? docId + 2 : -1, + (docId) -> floats3.get(docId) != null ? docId + 4 : -1 }; + + Map floats = Map.of( + 0, + new byte[] { 1, 2 }, + 1, + new byte[] { 2, 3 }, + 2, + new byte[] { 3, 4 }, + 3, + new byte[] { 4, 6 }, + 4, + new byte[] { 1, 2 } + ); + final TestVectorValues.PreDefinedByteVectorValues vectorValues1 = new TestVectorValues.PreDefinedByteVectorValues(floats1List); + final TestVectorValues.PreDefinedByteVectorValues vectorValues2 = new TestVectorValues.PreDefinedByteVectorValues(floats2List); + final TestVectorValues.PreDefinedByteVectorValues vectorValues3 = new TestVectorValues.PreDefinedByteVectorValues(floats3List); + + FieldInfo fieldInfo = fieldInfo(0, VectorEncoding.BYTE); + when(knnVectorsReader1.getByteVectorValues(FIELD)).thenReturn(vectorValues1); + when(knnVectorsReader2.getByteVectorValues(FIELD)).thenReturn(vectorValues2); + when(knnVectorsReader3.getByteVectorValues(FIELD)).thenReturn(vectorValues3); + when(fixedBitsLiveDocs.get(anyInt())).thenReturn(true); + when(fixedBitSetLiveDocs.get(anyInt())).thenReturn(true); + + final MergeState mergeState = mergeState(knnVectorsReaders, liveDocs, docMaps); + + // When + ByteVectorValues iterator = KNNMergeVectorValuesUtil.mergeByteVectorValues(fieldInfo, mergeState); + + // Then + assertEquals(5, iterator.size()); + int size = 0; + while (iterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(floats.get(iterator.docID()), iterator.vectorValue()); + size++; + } + assertEquals(floats.size(), size); + } + + private MergeState mergeState(KnnVectorsReader[] knnVectorsReaders, Bits[] liveDocs, MergeState.DocMap[] docMaps) { + return new MergeState( + docMaps, + null, + null, + null, + null, + null, + null, + null, + liveDocs, + null, + null, + knnVectorsReaders, + null, + null, + null, + false + ); + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.getName()).thenReturn(FIELD); + when(fieldInfo.hasVectorValues()).thenReturn(true); + return fieldInfo; + } +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java index 0b631ab41..2e9109aaf 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java @@ -151,5 +151,4 @@ void validateVectorValues( assertEquals(bytesPerVector, vectorValues.bytesPerVector); } } - }