From 541b653a3150fc845512345e6c29fc32a14d9739 Mon Sep 17 00:00:00 2001 From: Wei Wang <93847013+weiwang118@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:05:08 +0800 Subject: [PATCH] Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408) * Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter Signed-off-by: Wei Wang * fix typo error in test file Signed-off-by: Wei Wang --------- Signed-off-by: Wei Wang Signed-off-by: Wei Wang <93847013+weiwang118@users.noreply.github.com> (cherry picked from commit d58d133c6edb9dfc48b5c3e507cdc21dbf0477ad) --- CHANGELOG.md | 1 + .../NativeEngineFieldVectorsWriter.java | 15 +++++--------- .../NativeEngines990KnnVectorsWriter.java | 2 +- .../NativeEngineFieldVectorsWriterTests.java | 2 ++ ...eEngines990KnnVectorsWriterFlushTests.java | 20 ++++++++++--------- ...eEngines990KnnVectorsWriterMergeTests.java | 4 +++- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b32de7bc5..33ae2063d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] - Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320] - Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357] +- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365] diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index 389c76e49..88eee0ee7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -14,7 +14,6 @@ import lombok.Getter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; -import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.RamUsageEstimator; @@ -43,9 +42,8 @@ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { @Getter private final Map vectors; private int lastDocID = -1; - @Getter - private final DocsWithFieldSet docsWithField; private final InfoStream infoStream; + @Getter private final FlatFieldVectorsWriter flatFieldVectorsWriter; @SuppressWarnings("unchecked") @@ -75,7 +73,6 @@ private NativeEngineFieldVectorsWriter( this.fieldInfo = fieldInfo; this.infoStream = infoStream; vectors = new HashMap<>(); - this.docsWithField = new DocsWithFieldSet(); this.flatFieldVectorsWriter = flatFieldVectorsWriter; } @@ -101,7 +98,6 @@ public void addValue(int docID, T vectorValue) throws IOException { // ensuring that vector is provided to flatFieldWriter. flatFieldVectorsWriter.addValue(docID, vectorValue); vectors.put(docID, vectorValue); - docsWithField.add(docID); lastDocID = docID; } @@ -121,10 +117,9 @@ public T copyValue(T vectorValue) { */ @Override public long ramBytesUsed() { - return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance( - Integer.class - ) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter - .ramBytesUsed(); + return SHALLOW_SIZE + flatFieldVectorsWriter.getDocsWithFieldSet().ramBytesUsed() + (long) this.vectors.size() + * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() + * RamUsageEstimator.shallowSizeOfInstance(Integer.class) + (long) vectors.size() * fieldInfo.getVectorDimension() + * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter.ramBytesUsed(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 7c8636577..3966a2c95 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -100,7 +100,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { } final Supplier> knnVectorValuesSupplier = () -> getVectorValues( vectorDataType, - field.getDocsWithField(), + field.getFlatFieldVectorsWriter().getDocsWithFieldSet(), field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 4f68a360e..707ebb2a6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -13,6 +13,7 @@ import lombok.SneakyThrows; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.InfoStream; @@ -115,6 +116,7 @@ public void testRamByteUsed_whenValidInput_thenSuccess() { Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2); FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L); + Mockito.when(mockedFlatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(new DocsWithFieldSet()); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 03d0f6160..6685e2b22 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -161,7 +161,7 @@ public void testFlush() { throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -250,7 +250,7 @@ public void testFlush_WithQuantization() { throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -352,7 +352,7 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -429,7 +429,7 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -507,7 +507,7 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -593,7 +593,7 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -683,7 +683,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -786,7 +786,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres throw new RuntimeException(e); } - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) ).thenReturn(expectedVectorValues.get(i)); @@ -848,11 +848,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map< private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class); DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); when(fieldVectorsWriter.getVectors()).thenReturn(vectors); - when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter); + when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet); return fieldVectorsWriter; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 77f3fd8ed..cdc372bda 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -370,11 +370,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map< private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class); DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); when(fieldVectorsWriter.getVectors()).thenReturn(vectors); - when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter); + when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet); return fieldVectorsWriter; } }