From ad07bf03f2e8c381c31a0f1357cdedb0eebc1356 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Fri, 25 Oct 2024 13:44:41 -0700 Subject: [PATCH 01/18] Introduce derived vector source via stored fields Generates the vector source in the source field from the KnnVectorsFormat or BVD. It does this by adding StoredFieldsFormat to our existing custom codec. Work is still WIP but rootobject is working okay. Signed-off-by: John Mazanec --- .../opensearch/knn/common/KNNConstants.java | 4 + .../org/opensearch/knn/index/KNNSettings.java | 22 +- .../DerivedSourceStoredFieldsFormat.java | 74 +++ .../DerivedSourceStoredFieldsReader.java | 58 +++ .../DerivedSourceStoredFieldsWriter.java | 121 +++++ .../codec/KNN9120Codec/KNN9120Codec.java | 20 +- .../knn/index/codec/KNNCodecVersion.java | 1 + .../DerivedSourceReaderSupplier.java | 15 + .../derivedsource/DerivedSourceReaders.java | 23 + .../DerivedSourceReadersSupplier.java | 41 ++ .../DerivedSourceStoredFieldVisitor.java | 40 ++ .../DerivedSourceVectorInjector.java | 87 ++++ .../derivedsource/ParentChildHelper.java | 39 ++ .../PerFieldDerivedVectorInjector.java | 25 + .../PerFieldDerivedVectorInjectorFactory.java | 38 ++ .../RootPerFieldDerivedVectorInjector.java | 46 ++ .../index/mapper/FlatVectorFieldMapper.java | 18 +- .../index/mapper/KNNVectorFieldMapper.java | 28 +- .../knn/index/mapper/LuceneFieldMapper.java | 23 +- .../knn/index/mapper/MethodFieldMapper.java | 18 +- .../knn/index/mapper/ModelFieldMapper.java | 17 +- .../vectorvalues/KNNVectorValuesFactory.java | 31 ++ .../DerivedVectorInjectionConsumerTests.java | 48 ++ .../mapper/KNNVectorFieldMapperTests.java | 33 +- .../opensearch/knn/integ/DerivedSourceIT.java | 442 ++++++++++++++++++ .../org/opensearch/knn/KNNRestTestCase.java | 243 +++++++++- .../org/opensearch/knn/ODFERestTestCase.java | 4 + 27 files changed, 1513 insertions(+), 46 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java create 
mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java create mode 100644 src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 170cfabbea..7939837aaa 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -162,4 +162,8 @@ public class KNNConstants { public static final String MODE_PARAMETER = "mode"; public static final String COMPRESSION_LEVEL_PARAMETER = "compression_level"; + + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY = "knn-derived-source-enabled"; + 
public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE = "true"; + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_FALSE_VALUE = "false"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 8442af764c..b2f7b2f4c9 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -93,6 +93,7 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; + public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; /** * Default setting values @@ -269,6 +270,13 @@ public class KNNSettings { Setting.Property.Dynamic ); + public static final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( + KNN_DERIVED_SOURCE_ENABLED, + true, + IndexScope, + Setting.Property.Final + ); + /** * This setting identifies KNN index. 
*/ @@ -518,6 +526,9 @@ private Setting getSetting(String key) { if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; } + if (KNN_DERIVED_SOURCE_ENABLED.equals(key)) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING; + } throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -543,7 +554,8 @@ public List> getSettings() { KNN_FAISS_AVX512_SPR_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, + KNN_DERIVED_SOURCE_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -581,6 +593,14 @@ public static boolean isFaissAVX2Disabled() { } } + /** + * check this index enabled/disabled derived source + * @param settings Settings + */ + public static boolean isKNNDerivedSourceEnabled(Settings settings) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING.get(settings); + } + public static boolean isFaissAVX512Disabled() { return Booleans.parseBoolean( Objects.requireNonNullElse( diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java new file mode 100644 index 0000000000..cf242b0925 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.StoredFieldsReader; +import 
org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + +@AllArgsConstructor +public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { + + private final StoredFieldsFormat delegate; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + // IMPORTANT Do not rely on this for the reader, it will be null if SPI is used + private final MapperService mapperService; + + @Override + public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) + throws IOException { + List derivedVectorFields = new ArrayList<>(); + for (FieldInfo fieldInfo : fieldInfos) { + if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) { + derivedVectorFields.add(fieldInfo); + } + } + DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( + derivedSourceReadersSupplier, + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext), + derivedVectorFields + ); + return new 
DerivedSourceStoredFieldsReader( + delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), + derivedSourceVectorInjector + ); + } + + @Override + public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentInfo, IOContext ioContext) throws IOException { + StoredFieldsWriter delegateWriter = delegate.fieldsWriter(directory, segmentInfo, ioContext); + if (mapperService != null && KNNSettings.isKNNDerivedSourceEnabled(mapperService.getIndexSettings().getSettings())) { + List vectorFieldTypes = new ArrayList<>(); + for (MappedFieldType fieldType : mapperService.fieldTypes()) { + if (fieldType instanceof KNNVectorFieldType) { + vectorFieldTypes.add(fieldType.name()); + } + } + return new DerivedSourceStoredFieldsWriter(delegateWriter, vectorFieldTypes); + } + return delegateWriter; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java new file mode 100644 index 0000000000..233197ae6e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; + +import java.io.IOException; + +@RequiredArgsConstructor +public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { + private final StoredFieldsReader delegate; + // Given docId and source, process source + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + 
@Setter + private boolean shouldInject = true; + + @Override + public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { + if (shouldInject) { + delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector)); + return; + } + delegate.document(docId, storedFieldVisitor); + } + + @Override + public StoredFieldsReader clone() { + return new DerivedSourceStoredFieldsReader(delegate.clone(), derivedSourceVectorInjector); + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) { + if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) { + StoredFieldsReader storedFieldsReaderClone = storedFieldsReader.clone(); + ((DerivedSourceStoredFieldsReader) storedFieldsReaderClone).setShouldInject(false); + return storedFieldsReaderClone; + } + return storedFieldsReader; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java new file mode 100644 index 0000000000..1b3c8b3b12 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import 
org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@RequiredArgsConstructor +public class DerivedSourceStoredFieldsWriter extends StoredFieldsWriter { + + private final StoredFieldsWriter delegate; + private final List vectorFieldTypes; + + @Override + public void startDocument() throws IOException { + delegate.startDocument(); + } + + @Override + public void writeField(FieldInfo fieldInfo, int i) throws IOException { + delegate.writeField(fieldInfo, i); + } + + @Override + public void writeField(FieldInfo fieldInfo, long l) throws IOException { + delegate.writeField(fieldInfo, l); + } + + @Override + public void writeField(FieldInfo fieldInfo, float v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo fieldInfo, double v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo info, DataInput value, int length) throws IOException { + delegate.writeField(info, value, length); + } + + @Override + public int merge(MergeState mergeState) throws IOException { + // We have to wrap these here to avoid storing the vectors during merge + for (int i = 0; i < mergeState.storedFieldsReaders.length; i++) { + mergeState.storedFieldsReaders[i] = DerivedSourceStoredFieldsReader.wrapForMerge(mergeState.storedFieldsReaders[i]); + } + return delegate.merge(mergeState); + } + + @Override + public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOException { + // Parse out the vectors from the 
source. A BytesRef is a slice, so only [offset, offset + length) of its array is valid + if (Objects.equals(fieldInfo.name, SourceFieldMapper.NAME) && !vectorFieldTypes.isEmpty()) { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)), + true, + MediaTypeRegistry.JSON + ); + Map filteredSource = XContentMapValues.filter(null, vectorFieldTypes.toArray(new String[0])) + .apply(mapTuple.v2()); + BytesStreamOutput bStream = new BytesStreamOutput(); + MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(filteredSource); + builder.close(); + BytesReference bytesReference = bStream.bytes(); + delegate.writeField(fieldInfo, bytesReference.toBytesRef()); + return; + } + delegate.writeField(fieldInfo, bytesRef); + } + + @Override + public void writeField(FieldInfo fieldInfo, String s) throws IOException { + delegate.writeField(fieldInfo, s); + } + + @Override + public void finishDocument() throws IOException { + delegate.finishDocument(); + } + + @Override + public void finish(int i) throws IOException { + delegate.finish(i); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public long ramBytesUsed() { + return delegate.ramBytesUsed(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index a370197ecc..e0d9b678b5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -11,9 +11,12 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; import 
org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; /** * KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12 @@ -23,11 +26,13 @@ public class KNN9120Codec extends FilterCodec { private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final MapperService mapperService; + /** * No arg constructor that uses Lucene99 as the delegate */ public KNN9120Codec() { - this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null); } /** @@ -38,10 +43,11 @@ public KNN9120Codec() { * @param knnVectorsFormat per field format for KnnVector */ @Builder - protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) { super(VERSION.getCodecName(), delegate); knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; + this.mapperService = mapperService; } @Override @@ -58,4 +64,14 @@ public CompoundFormat compoundFormat() { public KnnVectorsFormat knnVectorsFormat() { return perFieldKnnVectorsFormat; } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier( + (segmentReadState) -> knnVectorsFormat().fieldsReader(segmentReadState), + (segmentReadState) -> docValuesFormat().fieldsProducer(segmentReadState), + (segmentReadState) -> postingsFormat().fieldsProducer(segmentReadState) + ); + return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); + } } diff --git 
a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 3df040785b..6af6591f68 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -126,6 +126,7 @@ public enum KNNCodecVersion { (userCodec, mapperService) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .mapperService(mapperService) .build(), KNN9120Codec::new ); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java new file mode 100644 index 0000000000..123b718a46 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +@FunctionalInterface +public interface DerivedSourceReaderSupplier { + R apply(SegmentReadState segmentReadState) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java new file mode 100644 index 0000000000..5bdcc5181d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import 
org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; + +/** + * Class holds the readers necessary to implement derived source. + */ +@RequiredArgsConstructor +@Getter +public class DerivedSourceReaders { + private final KnnVectorsReader knnVectorsReader; + private final DocValuesProducer docValuesProducer; + private final FieldsProducer fieldsProducer; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java new file mode 100644 index 0000000000..8e46952c2f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +/** + * Class encapsulates the suppliers to give the {@link DerivedSourceReaders} from particular formats needed to implement + * derived source. More specifically, given a {@link org.apache.lucene.index.SegmentReadState}, this class will provide + * the correct format reader for that segment. 
+ */ +@RequiredArgsConstructor +public class DerivedSourceReadersSupplier { + private final DerivedSourceReaderSupplier knnVectorsReaderSupplier; + private final DerivedSourceReaderSupplier docValuesProducerSupplier; + private final DerivedSourceReaderSupplier fieldsProducerSupplier; + + /** + * Get the readers for the segment + * + * @param state SegmentReadState + * @return DerivedSourceReaders + * @throws IOException in case of I/O error + */ + public DerivedSourceReaders getReaders(SegmentReadState state) throws IOException { + return new DerivedSourceReaders( + knnVectorsReaderSupplier.apply(state), + docValuesProducerSupplier.apply(state), + fieldsProducerSupplier.apply(state) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java new file mode 100644 index 0000000000..41a01c15c7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.AllArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; + +/** + * Custom {@link StoredFieldVisitor} that wraps an upstream delegate visitor in order to transparently inject derived + * source vector fields into the document. After the source is modified, it is forwarded to the delegate. 
+ */ +@AllArgsConstructor +public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor { + + private final StoredFieldVisitor delegate; + private final Integer documentId; + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + @Override + public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { + // TODO: Add skip condition here if the delegate specifies which fields are not required for source + if (fieldInfo.name.equals(SourceFieldMapper.NAME)) { + delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value)); + return; + } + delegate.binaryField(fieldInfo, value); + } + + @Override + public Status needsField(FieldInfo fieldInfo) throws IOException { + return delegate.needsField(fieldInfo); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java new file mode 100644 index 0000000000..520b0bbf56 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * 
This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative + * format readers and information about the fields to inject vectors into the source. + */ +@Log4j2 +public class DerivedSourceVectorInjector { + + private final List perFieldDerivedVectorInjectors; + + /** + * Constructor for DerivedSourceVectorInjector. + * + * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param segmentReadState Segment read state + * @param fieldsToInjectVector List of fields to inject vectors into + */ + public DerivedSourceVectorInjector( + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + List fieldsToInjectVector + ) throws IOException { + DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + this.perFieldDerivedVectorInjectors = new ArrayList<>(); + for (FieldInfo fieldInfo : fieldsToInjectVector) { + this.perFieldDerivedVectorInjectors.add( + PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState) + ); + } + } + + /** + * Given a docId and the source of that doc as bytes, add all the necessary vector fields into the source. + * + * @param docId doc id of the document + * @param sourceAsBytes source of document as bytes + * @return byte array of the source with the vector fields added + * @throws IOException if there is an issue reading from the formats + */ + public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOException { + // Deserialize the source into a modifiable map + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + // Have to create a copy of the map here to ensure that is mutable + Map sourceAsMap = new HashMap<>(mapTuple.v2()); + + // For each vector field, add in the source. 
The per field injectors are responsible for skipping if + // the field is not present. + for (PerFieldDerivedVectorInjector vectorInjector : perFieldDerivedVectorInjectors) { + vectorInjector.inject(docId, sourceAsMap); + } + + // At this point, we can serialize the modified source map + BytesStreamOutput bStream = new BytesStreamOutput(1024); + MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap); + builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java new file mode 100644 index 0000000000..c755e45d27 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +/** + * Helper class for working with nested fields. + */ +public class ParentChildHelper { + + /** + * Given a nested field path, return the path of the parent field. For instance if the field is "parent.to.child", + * this would return "parent.to". + * + * @param field nested field path + * @return parent field path without the child + */ + public static String getParentField(String field) { + int lastDot = field.lastIndexOf('.'); + if (lastDot == -1) { + return null; + } + return field.substring(0, lastDot); + } + + /** + * Given a nested field path, return the child field. For instance if the field is "parent.to.child", this would + * return "child". 
+ * + * @param field nested field path + * @return child field path without the parent path + */ + public static String getChildField(String field) { + int lastDot = field.lastIndexOf('.'); + return field.substring(lastDot + 1); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..2467c26108 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import java.io.IOException; +import java.util.Map; + +/** + * Interface for injecting derived vectors into a source map per field. + */ +public interface PerFieldDerivedVectorInjector { + + /** + * Injects the derived vector for this field into the sourceAsMap. Implementing classes must handle the case where + * a document does not have a value for their field. 
+ * + * @param docId Document ID + * @param sourceAsMap Source as map + * @throws IOException if there is an issue reading from the formats + */ + void inject(Integer docId, Map sourceAsMap) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java new file mode 100644 index 0000000000..f7ab3d5273 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; + +/** + * Factory for creating {@link PerFieldDerivedVectorInjector} instances. + */ +public class PerFieldDerivedVectorInjectorFactory { + + /** + * Create a {@link PerFieldDerivedVectorInjector} instance based on information in field info. + * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @return PerFieldDerivedVectorInjector instance + */ + public static PerFieldDerivedVectorInjector create( + FieldInfo fieldInfo, + DerivedSourceReaders derivedSourceReaders, + SegmentReadState segmentReadState + ) { + // Nested case + if (ParentChildHelper.getParentField(fieldInfo.name) != null) { + throw new IllegalArgumentException( + String.format("Field %s is a nested field. 
Nested fields are not supported by the derived source codec.", fieldInfo.name) + ); + } + + // Non-nested case + return new RootPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..4812dca3c6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.CheckedSupplier; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; +import java.util.Map; + +/** + * {@link PerFieldDerivedVectorInjector} for root fields (i.e. non-nested fields). + */ +public class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + + private final FieldInfo fieldInfo; + private final CheckedSupplier<KNNVectorValues<?>, IOException> vectorValuesSupplier; + + /** + * Constructor for RootPerFieldDerivedVectorInjector.
+ * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + */ + public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReaders derivedSourceReaders) { + this.fieldInfo = fieldInfo; + this.vectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues( + fieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + } + + @Override + public void inject(Integer docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = vectorValuesSupplier.get(); + if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) { + sourceAsMap.put(fieldInfo.name, vectorValues.getVector()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 9f1ebcf018..68ea25a1fc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -14,6 +14,9 @@ import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + /** * Mapper used when you dont want to build an underlying KNN struct - you just want to * store vectors as doc values @@ -32,7 +35,8 @@ public static FlatVectorFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -49,7 +53,8 @@ public static FlatVectorFieldMapper createFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - 
originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -62,7 +67,8 @@ private FlatVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -73,13 +79,17 @@ private FlatVectorFieldMapper( stored, hasDocValues, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.setDocValuesType(DocValuesType.BINARY); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.fieldType.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 99c6ebe2a1..be485847c2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -93,6 +93,7 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { */ public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); @@ -200,13 +201,15 @@ public Builder( ModelDao modelDao, Version indexCreatedVersion, KNNMethodConfigContext knnMethodConfigContext, - 
OriginalMappingParameters originalParameters + OriginalMappingParameters originalParameters, + boolean isDerivedSourceEnabled ) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; this.knnMethodConfigContext = knnMethodConfigContext; this.originalParameters = originalParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } @Override @@ -258,7 +261,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { modelDao, indexCreatedVersion, originalParameters, - knnMethodConfigContext + knnMethodConfigContext, + isDerivedSourceEnabled ); } @@ -280,7 +284,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -301,7 +306,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { metaValue, knnMethodConfigContext, createLuceneFieldMapperInput, - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -315,7 +321,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.getValue(), hasDocValues.getValue(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -363,7 +370,8 @@ public Mapper.Builder parse(String name, Map node, ParserCont modelDaoSupplier.get(), parserContext.indexVersionCreated(), null, - null + null, + KNNSettings.isKNNDerivedSourceEnabled(parserContext.getSettings()) ); builder.parse(name, parserContext, node); builder.setOriginalParameters(new OriginalMappingParameters(builder)); @@ -569,6 +577,7 @@ static boolean useKNNMethodContextFromLegacy(Builder builder, Mapper.TypeParser. // values of KNN engine Algorithms hyperparameters. 
protected Version indexCreatedVersion; protected Explicit ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected boolean stored; protected boolean hasDocValues; protected VectorDataType vectorDataType; @@ -589,7 +598,8 @@ public KNNVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; @@ -599,6 +609,7 @@ public KNNVectorFieldMapper( updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; this.originalMappingParameters = originalMappingParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } public KNNVectorFieldMapper clone() { @@ -831,7 +842,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { modelDao, indexCreatedVersion, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ).init(this); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4ceb9b4b23..2abbf182d1 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -27,6 +27,8 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static 
org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; @@ -48,7 +50,8 @@ static LuceneFieldMapper createFieldMapper( Map metaValue, KNNMethodConfigContext knnMethodConfigContext, CreateLuceneFieldMapperInput createLuceneFieldMapperInput, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -82,14 +85,21 @@ public Version getIndexCreatedVersion() { } ); - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext, originalMappingParameters); + return new LuceneFieldMapper( + mappedFieldType, + createLuceneFieldMapperInput, + knnMethodConfigContext, + originalMappingParameters, + isDerivedSourceEnabled + ); } private LuceneFieldMapper( final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( input.getName(), @@ -100,7 +110,8 @@ private LuceneFieldMapper( input.isStored(), input.isHasDocValues(), knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); @@ -117,6 +128,10 @@ private LuceneFieldMapper( this.vectorFieldType = null; } + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } + KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); 
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 814bc4f639..a2635b1953 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -51,7 +53,8 @@ public static MethodFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); @@ -104,7 +107,8 @@ public Version getIndexCreatedVersion() { stored, hasDocValues, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -117,7 +121,8 @@ private MethodFieldMapper( boolean stored, boolean hasDocValues, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( @@ -129,7 +134,8 @@ private MethodFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); this.useLuceneBasedVectorField = 
KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); @@ -151,7 +157,9 @@ private MethodFieldMapper( this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); - + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } try { this.fieldType.putAttribute( PARAMETERS, diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index d472090fc3..ae912aa415 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -28,6 +28,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; @@ -60,7 +62,8 @@ public static ModelFieldMapper createFieldMapper( ModelDao modelDao, Version indexCreatedVersion, OriginalMappingParameters originalMappingParameters, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext knnMethodConfigContext, + boolean isDerivedSourceEnabled ) { final KNNMethodContext knnMethodContext = originalMappingParameters.getKnnMethodContext(); @@ -134,7 +137,8 @@ private void initFromModelMetadata() { hasDocValues, modelDao, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -148,7 +152,8 @@ private ModelFieldMapper( boolean hasDocValues, ModelDao modelDao, Version indexCreatedVersion, - OriginalMappingParameters 
originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -159,7 +164,8 @@ private ModelFieldMapper( stored, hasDocValues, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); @@ -174,6 +180,9 @@ private ModelFieldMapper( this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e2172..699d62843b 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -72,6 +74,35 @@ public static KNNVectorValues getVectorValues(final FieldInfo fieldInfo, return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); } + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo}, read from the {@link KnnVectorsReader} when the field has vector values and from the {@link DocValuesProducer} otherwise + * + * @param fieldInfo {@link FieldInfo} + * @param docValuesProducer {@link DocValuesProducer} +
* @param knnVectorsReader {@link KnnVectorsReader} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues( + final FieldInfo fieldInfo, + final DocValuesProducer docValuesProducer, + final KnnVectorsReader knnVectorsReader + ) throws IOException { + final DocIdSetIterator docIdSetIterator; + if (fieldInfo.hasVectorValues()) { + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + docIdSetIterator = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + docIdSetIterator = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); + } else { + throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); + } + } else { + docIdSetIterator = docValuesProducer.getBinary(fieldInfo); + } + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + @SuppressWarnings("unchecked") private static KNNVectorValues getVectorValues( final VectorDataType vectorDataType, diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java new file mode 100644 index 0000000000..1db2f22bb0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.opensearch.knn.KNNTestCase; + +public class DerivedVectorInjectionConsumerTests extends KNNTestCase { + // + // @SneakyThrows + // public void testVectorInjection() { + // FloatVectorValues randomVectorValues = new 
TestVectorValues.PreDefinedFloatVectorValues( + // List.of(new float[] { 1.0f, 2.0f }, new float[] { 2.0f, 3.0f }, new float[] { 3.0f, 4.0f }, new float[] { 4.0f, 5.0f }) + // ); + // final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + // + // final XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + // builder.field("test_text", "text-field"); + // builder.endObject(); + // + // BytesReference bytesReference = BytesReference.bytes(builder); + // toMap(bytesReference); + // + // DerivedVectorInjectionConsumer consumer = new DerivedVectorInjectionConsumer(Map.of("test_vector", () -> knnVectorValues)); + // logger.info(bytesReference.length()); + // byte[] modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); + // BytesReference modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + // toMap(modifiedBytesReference); + // + // modifiedBytes = consumer.apply(1, bytesReference.toBytesRef().bytes); + // modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + // toMap(modifiedBytesReference); + // + // modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); + // modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + // toMap(modifiedBytesReference); + // + // fail("On purpose"); + // } + // + // private void toMap(BytesReference source) { + // Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(source, true, MediaTypeRegistry.JSON); + // logger.info(mapTuple.v2().toString()); + // } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 49b15a0f43..52ad3ded31 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@
-123,7 +123,8 @@ public void testBuilder_getParameters() { modelDao, CURRENT, null, - new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()) + new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()), + false ); assertEquals(10, builder.getParameters().size()); @@ -357,7 +358,7 @@ public void testTypeParser_withSpaceTypeAndMode_thenSuccess() throws IOException public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -1177,7 +1178,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1216,7 +1218,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1288,7 +1291,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false ); modelFieldMapper.parseCreateField(parseContext); @@ -1330,7 +1334,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false 
); modelFieldMapper.parseCreateField(parseContext); @@ -1376,7 +1381,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { Collections.emptyMap(), knnMethodConfigContext, inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1435,7 +1441,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { Collections.emptyMap(), knnMethodConfigContext, inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1482,7 +1489,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1532,7 +1540,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1647,7 +1656,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithSQ_thenException() throws IOEx public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1696,7 +1705,7 @@ public void 
testBuild_whenInvalidCharsInFieldName_thenThrowException() { // IllegalArgumentException should be thrown. Exception e = assertThrows(IllegalArgumentException.class, () -> { - new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null).build(builderContext); + new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null, false).build(builderContext); }); assertTrue(e.getMessage(), e.getMessage().contains("Vector field name must not include")); } diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java new file mode 100644 index 0000000000..2153d96f80 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -0,0 +1,442 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.primitives.Floats; +import lombok.SneakyThrows; +import org.junit.Ignore; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; + +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; + +/** + * Integration tests for derived source feature for vector fields. Currently, with derived source, there are + * a few gaps in functionality. + * //TODO: Dimensions: + * // 1. Data type + * // 2. Dimension + * // 3. Nested level + * // 4. Vectors per field + * // 5. Other fields + * // 6. 
Minimum number of values + */ +public class DerivedSourceIT extends KNNRestTestCase { + + private final static String NESTED_NAME = "test_nested"; + private final static String FIELD_NAME = "test_vector"; + private final int TEST_DIMENSION = 128; + private final int DOCS = 50; + + private static final Settings DERIVED_ENABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, true) + .build(); + private static final Settings DERIVED_DISABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, false) + .build(); + + @SneakyThrows + public void testFlatBaseCase() { + String indexNameDerivedSourceEnabled = ("enabled-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String indexNameDerivedSourceDisabled = ("disabled-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + prepareFlatIndex(indexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + prepareFlatIndex(indexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, indexNameDerivedSourceDisabled); + forceMergeKnnIndex(indexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(indexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(indexNameDerivedSourceDisabled, indexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, indexNameDerivedSourceDisabled); + refreshAllIndices(); + forceMergeKnnIndex(indexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(indexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(indexNameDerivedSourceDisabled, indexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, 
indexNameDerivedSourceDisabled); + } + + @SneakyThrows + public void testFlatReindex() { + String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String reindexFromEnabledToEnabledIndexName = ("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String reindexFromEnabledToDisabledIndexName = ("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String reindexFromDisabledToEnabledIndexName = ("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String reindexFromDisabledToDisabledIndexName = ("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + + prepareFlatIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + prepareFlatIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + createKnnIndex(reindexFromEnabledToEnabledIndexName, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + createKnnIndex(reindexFromEnabledToDisabledIndexName, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + createKnnIndex(reindexFromDisabledToEnabledIndexName, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + createKnnIndex(reindexFromDisabledToDisabledIndexName, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + + refreshAllIndices(); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToDisabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, 
reindexFromDisabledToDisabledIndexName); + refreshAllIndices(); + + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromEnabledToDisabledIndexName); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); + } + + @SneakyThrows + public void testFlatDeletesAndUpdates() { + String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + prepareFlatIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + prepareFlatIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); + + int docWithVectorUpdate = DOCS - 4; + int docWithVectorRemoval = 1; + int docWithVectorUpdateFromAPI = 2; + int docWithUpdateByQuery = 7; + int docToDelete = 8; + int docToDeleteByQuery = 11; + + float[] updateVector = randomFloatVector(TEST_DIMENSION); + updateKnnDoc( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdate), + FIELD_NAME, + Floats.asList(updateVector).toArray() + ); + updateKnnDoc( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdate), + FIELD_NAME, + 
Floats.asList(updateVector).toArray() + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); + setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdateFromAPI), + FIELD_NAME, + Floats.asList(updateVector).toArray() + ); + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdateFromAPI), + FIELD_NAME, + Floats.asList(updateVector).toArray() + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + updateKnnDocByQuery( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithUpdateByQuery), + FIELD_NAME, + Floats.asList(updateVector).toArray() + ); + updateKnnDocByQuery( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithUpdateByQuery), + FIELD_NAME, + Floats.asList(updateVector).toArray() + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + deleteKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDelete)); + deleteKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDelete)); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + deleteKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDeleteByQuery)); + deleteKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDeleteByQuery)); + refreshAllIndices(); + assertDocsMatch(DOCS, 
originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + } + + @SneakyThrows + public void testMultiFlatFields() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject("text") + .field(TYPE, "text") + .endObject() + .endObject() + .endObject(); + + String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + + createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, builder.toString()); + createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, builder.toString()); + bulkIngestRandomVectorsWithSkipsAndMultFields( + originalIndexNameDerivedSourceEnabled, + FIELD_NAME + "1", + FIELD_NAME + "2", + "text", + DOCS, + TEST_DIMENSION, + 0.1f + ); + bulkIngestRandomVectorsWithSkipsAndMultFields( + originalIndexNameDerivedSourceDisabled, + FIELD_NAME + "1", + FIELD_NAME + "2", + "text", + DOCS, + TEST_DIMENSION, + 0.1f + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + 
forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + + int docWithVectorUpdate = DOCS - 4; + int docWithUpdateByQuery = 7; + + float[] updateVector = randomFloatVector(TEST_DIMENSION); + updateKnnDoc( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdate), + FIELD_NAME + "1", + Floats.asList(updateVector).toArray() + ); + updateKnnDoc( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdate), + FIELD_NAME + "1", + Floats.asList(updateVector).toArray() + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + updateKnnDocByQuery( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithUpdateByQuery), + FIELD_NAME + "2", + Floats.asList(updateVector).toArray() + ); + updateKnnDocByQuery( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithUpdateByQuery), + FIELD_NAME + "2", + Floats.asList(updateVector).toArray() + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + } + + @Ignore + @SneakyThrows + public void testNestedSingleDocBasic() { + // For basic tests, we will have 0-5 nested documents per document + String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); + + createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); + createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); + +
bulkIngestRandomVectorsWithSkipsAndNested( + originalIndexNameDerivedSourceEnabled, + NESTED_NAME + "." + FIELD_NAME, + NESTED_NAME + "." + "text", + DOCS, + TEST_DIMENSION, + 0.1f + ); + bulkIngestRandomVectorsWithSkipsAndNested( + originalIndexNameDerivedSourceDisabled, + NESTED_NAME + "." + FIELD_NAME, + NESTED_NAME + "." + "text", + DOCS, + TEST_DIMENSION, + 0.1f + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + } + + @Ignore + @SneakyThrows + public void testNestedMultiDocBasic() { + String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6).toLowerCase(Locale.ROOT)); // "test"); + // /*randomAlphaOfLength(4)).toLowerCase(Locale.ROOT)*/; + String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6).toLowerCase(Locale.ROOT)); // + ; + // //"test"); + // /*randomAlphaOfLength(4)).toLowerCase(Locale.ROOT)*/; + + createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); + createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); + + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + originalIndexNameDerivedSourceEnabled, + NESTED_NAME + "." + FIELD_NAME, + NESTED_NAME + "." 
+ "text", + DOCS, + TEST_DIMENSION, + 0.1f, + 5 + ); + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + originalIndexNameDerivedSourceDisabled, + NESTED_NAME + "." + FIELD_NAME, + NESTED_NAME + "." + "text", + DOCS, + TEST_DIMENSION, + 0.1f, + 5 + ); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + } + + // public void testNestedReindex() { + // + // } + // + // public void testNestedUpdateAndDelete() { + // + // } + // + // public void testMultiNestedFields() { + // // TODO + // } + // + // public void testMixedNestedAndFlatFields() { + // // TODO + // } + // + // public void testFLSSupport() { + // // TODO: Security only - need to figure out how to configure this one better + // } + // + // public void testNullSet() { + // // TODO: we know this breaks + // } + + @SneakyThrows + private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { + assertTrue(indexSizeInBytes(expectedSmallerIndex) < indexSizeInBytes(expectedBiggerIndex)); + } + + @SneakyThrows + private void prepareFlatIndex(String indexName, Settings settings, String mapping) { + createKnnIndex(indexName, settings, mapping); + bulkIngestRandomVectorsWithSkips(indexName, FIELD_NAME, DOCS, TEST_DIMENSION, 0.1f); + refreshAllIndices(); + } + + private void 
assertDocsMatch(int docCount, String index1, String index2) { + for (int i = 0; i < docCount; i++) { + assertDocMatches(i + 1, index1, index2); + } + } + + @SneakyThrows + private void assertDocMatches(int docId, String index1, String index2) { + Map response1 = getKnnDoc(index1, String.valueOf(docId)); + Map response2 = getKnnDoc(index2, String.valueOf(docId)); + assertEquals("Docs do not match: " + docId, response1, response2); + } + + @SneakyThrows + private String createVectorNonNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + + @SneakyThrows + private String createVectorNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(NESTED_NAME) + .field(TYPE, "nested") + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 381b368c04..95f2a11a0c 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -27,6 +27,7 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.derivedsource.ParentChildHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -70,6 +71,7 
@@ import java.util.Objects; import java.util.Optional; import java.util.PriorityQueue; +import java.util.Random; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -696,6 +698,28 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void addDocWithNestedNumericField(String index, String docId, String nestedFieldPath, long val) throws IOException { + String[] fieldParts = nestedFieldPath.split("\\."); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], val); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(builder.toString()); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Add a single KNN Doc to an index with multiple fields */ @@ -771,6 +795,76 @@ protected void updateKnnDoc(String index, String docId, String fieldName, Object assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Update a KNN Doc with a new vector for the given fieldName + */ + protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true"); + + XContentBuilder builder 
= XContentFactory.jsonBuilder() + .startObject() + .startObject("doc") + .field(fieldName, vector) + .endObject() + .endObject(); + + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void updateKnnDocByQuery(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request("POST", "/" + index + "/_update_by_query?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("term") + .field("id", docId) + .endObject() + .endObject() + .startObject("script") + .field("source", "ctx._source." + fieldName + " = params.newValue") + .field("lang", "painless") + .startObject("params") + .field("newValue", vector) + .endObject() + .endObject() + .endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void deleteKnnDocByQuery(String index, String docId) throws IOException { + // Delete the docs matched by a term query on "id". NOTE(review): confirm "id" rather than "_id" is the intended field + Request request = new Request("POST", "/" + index + "/_delete_by_query?refresh"); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("term") + .field("id", docId) + .endObject() + .endObject() + .endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Re-index the doc under the same id with an empty JSON object as its body + */ + protected void setDocToEmpty(String index, String docId) throws IOException { +
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().endObject(); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Delete Knn Doc */ @@ -787,6 +881,7 @@ protected void deleteKnnDoc(String index, String docId) throws IOException { */ protected Map getKnnDoc(final String index, final String docId) throws Exception { final Request request = new Request("GET", "/" + index + "/_doc/" + docId); + request.addParameter("ignore", "404"); final Response response = client().performRequest(request); final Map responseMap = createParser( @@ -795,8 +890,8 @@ protected Map getKnnDoc(final String index, final String docId) ).map(); assertNotNull(responseMap); - assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); - assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); + // assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); + // assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); final Map docMap = (Map) responseMap.get(DOCUMENT_FIELD_SOURCE); @@ -819,6 +914,22 @@ protected void updateClusterSettings(String settingKey, Object value) throws Exc assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void reindex(String source, Object destination) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("source") + .field("index", source) + .endObject() + .startObject("dest") + .field("index", destination) + .endObject() + .endObject(); + Request request = new Request("POST", "_reindex"); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, 
RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Return default index settings for index creation */ @@ -896,6 +1007,21 @@ protected Response executeKnnStatRequest(List nodeIds, List stat return response; } + protected int indexSizeInBytes(String indexName) throws IOException { + Request request = new Request("GET", indexName + "/_stats" + "/store"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String responseBody = EntityUtils.toString(response.getEntity()); + + @SuppressWarnings("unchecked") + Integer sizeInBytes = (Integer) ((Map) ((Map) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("_all")).get("primaries")).get("store")).get("size_in_bytes"); + + return sizeInBytes; + } + @SneakyThrows protected void doKnnWarmup(List indices) { Response response = knnWarmup(indices); @@ -1243,15 +1369,120 @@ public Map xContentBuilderToMap(XContentBuilder xContentBuilder) } public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + // TODO: Do better on this one + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); for (int i = 0; i < numVectors; i++) { - float[] vector = new float[dimension]; - for (int j = 0; j < dimension; j++) { - vector[j] = randomFloat(); + float[] vector = vectors[i]; + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } + } + + public void bulkIngestRandomVectorsWithSkips(String indexName, String fieldName, int numVectors, int dimension, float skipProb) + throws IOException { + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector = vectors[i]; + if (random.nextFloat() > 
skipProb) { + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); } + } + } - addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + public void bulkIngestRandomVectorsWithSkipsAndMultFields( + String indexName, + String fieldName1, + String fieldName2, + String fieldName3, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + float[][] vectors1 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + float[][] vectors2 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 8); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector1 = vectors1[i]; + float[] vector2 = vectors2[i]; + if (random.nextFloat() > skipProb) { + addKnnDoc( + indexName, + String.valueOf(i + 1), + XContentFactory.jsonBuilder() + .startObject() + .field(fieldName1, vector1) + .field(fieldName2, vector2) + .field(fieldName3, "test-test") + .endObject() + .toString() + ); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } } + } + public void bulkIngestRandomVectorsWithSkipsAndNested( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + indexName, + nestedFieldName, + nestedNumericPath, + numVectors, + dimension, + skipProb, + 1 + ); + } + + public void bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numDocs, + int dimension, + float skipProb, + int maxDoc + ) throws IOException { + Random random = new Random(); + random.setSeed(2); + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numDocs * maxDoc, dimension, 1); + for (int i = 0; i < 
numDocs; i++) { + int nestedDocs = random.nextInt(maxDoc) + 1; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startArray(ParentChildHelper.getParentField(nestedFieldName)); + for (int j = 0; j < nestedDocs; j++) { + builder.startObject(); + if (random.nextFloat() > skipProb) { + builder.field(ParentChildHelper.getChildField(nestedFieldName), vectors[i * maxDoc + j]); + } else { + builder.field(ParentChildHelper.getChildField(nestedNumericPath), 1); + } + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); + } + } + + public float[] randomFloatVector(int dimension) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomFloat(); + } + return vector; } /** diff --git a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java index efdde63c58..4fc4cc8be0 100644 --- a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java @@ -86,6 +86,10 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE return builder.build(); } + protected boolean preserveClusterUponCompletion() { + return false; + } + protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { // Similar to client configuration with OpenSearch: // https://github.com/opensearch-project/OpenSearch/blob/2.11.1/test/framework/src/main/java/org/opensearch/test/rest/OpenSearchRestTestCase.java#L841-L863 From 31de6723f96f22db1eb58f83c1d8006d9b3e2c6c Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 27 Jan 2025 21:30:39 -0800 Subject: [PATCH 02/18] Initial attempt at first level nesting Signed-off-by: John Mazanec --- .../NestedPerFieldDerivedVectorInjector.java | 192 ++++++++++++++++++ .../NestedPerFieldParentToDocIdIterator.java | 166
+++++++++++++++ .../derivedsource/ParentChildHelper.java | 12 ++ .../PerFieldDerivedVectorInjectorFactory.java | 4 +- .../opensearch/knn/integ/DerivedSourceIT.java | 7 +- .../org/opensearch/knn/KNNRestTestCase.java | 1 + 6 files changed, 374 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..f3c72a27c0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BytesRef; +import org.opensearch.index.mapper.FieldNamesFieldMapper; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +@Log4j2 +public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + + private final FieldInfo childFieldInfo; + private final DerivedSourceReaders 
derivedSourceReaders; + private final SegmentReadState segmentReadState; + + public NestedPerFieldDerivedVectorInjector( + FieldInfo childFieldInfo, + DerivedSourceReaders derivedSourceReaders, + SegmentReadState segmentReadState + ) { + this.childFieldInfo = childFieldInfo; + this.derivedSourceReaders = derivedSourceReaders; + this.segmentReadState = segmentReadState; + } + + @Override + public void inject(Integer parentDocId, Map sourceAsMap) throws IOException { + // Setup the iterator. Return if not-relevant + String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name); + String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name); + if (parentFieldName == null) { + return; + } + NestedPerFieldParentToDocIdIterator nestedPerFieldParentToDocIdIterator = new NestedPerFieldParentToDocIdIterator( + childFieldInfo, + segmentReadState, + derivedSourceReaders, + parentDocId + ); + + // Initializes the parent field so that there is a list to put each of the children + Object originalParentValue = sourceAsMap.get(parentFieldName); + List> reconstructedSource; + if (originalParentValue == null) { + // The parent field is absent from the source; start from an empty list instead of dereferencing null below + reconstructedSource = new ArrayList<>(); + } else if (originalParentValue instanceof Map) { + reconstructedSource = new ArrayList<>(List.of((Map) originalParentValue)); + } else { + reconstructedSource = (List>) originalParentValue; + } + + // Contains the positions of existing objects in the map. This is used to help figure out the best place to put back the vectors + List positions = mapObjectsToPositionInSource( + reconstructedSource, + nestedPerFieldParentToDocIdIterator.firstChild(), + parentDocId + ); + + // Finally, inject children for the document into the source. This code is non-trivial because filtering out + // the vectors during write could mean that children docs disappear from the source.
So, to properly put + // everything back, we need to figure out where the existing fields in the original map to + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( + childFieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + int offsetPositionsIndex = 0; + while (nestedPerFieldParentToDocIdIterator.nextChild() != NO_MORE_DOCS) { + // If the child does not have a vector, vectorValues advance will advance past child to the next matching + // docId. So, we need to ensure that doing this does not pass the parent docId. + if (nestedPerFieldParentToDocIdIterator.childId() > vectorValues.docId()) { + vectorValues.advance(nestedPerFieldParentToDocIdIterator.childId()); + } + if (vectorValues.docId() != nestedPerFieldParentToDocIdIterator.childId()) { + continue; + } + int docId = vectorValues.docId(); + if (docId >= parentDocId) { + break; + } + boolean isInsert = true; + int position = positions.size(); // by default we insert it at the end + for (int i = offsetPositionsIndex; i < positions.size(); i++) { + if (docId < positions.get(i)) { + position = i; + break; + } + if (docId == positions.get(i)) { + isInsert = false; + position = i; + break; + } + } + + if (isInsert) { + reconstructedSource.add(position, new HashMap<>()); + positions.add(position, docId); + } + reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); + offsetPositionsIndex = position + 1; + } + // Only write the parent back if it already existed or a child was restored, so no empty parent entry is invented + if (originalParentValue != null || !reconstructedSource.isEmpty()) { + sourceAsMap.put(parentFieldName, reconstructedSource); + } + } + + private List mapObjectsToPositionInSource(List> originals, int firstChild, int parent) throws IOException { + List positions = new ArrayList<>(); + int offset = firstChild; + for (Map docWithFields : originals) { + int fieldMapping = docToOrdinal(docWithFields, offset, parent); + assert fieldMapping != -1; + positions.add(fieldMapping); + offset = fieldMapping + 1; + } + return positions; + } + + // Offset is first eligible object +
private Integer docToOrdinal(Map doc, int offset, int parent) throws IOException { + String keyToCheck = doc.keySet().iterator().next(); + int position = getFieldsForDoc(keyToCheck, offset); + // Advancing past the parent means something went horribly wrong + assert position < parent; + return position; + } + + private int getFieldsForDoc(String fieldToMatch, int offset) throws IOException { + // TODO: Fix this up to follow + // https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java#L170-L218. + // In a perfect world, it would try everything and fall through to the field exists stuff + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo( + ParentChildHelper.constructSiblingField(childFieldInfo.name, fieldToMatch) + ); + DocIdSetIterator iterator = null; + if (fieldInfo != null) { + switch (fieldInfo.getDocValuesType()) { + case NONE: + break; + case NUMERIC: + iterator = derivedSourceReaders.getDocValuesProducer().getNumeric(fieldInfo); + break; + case BINARY: + iterator = derivedSourceReaders.getDocValuesProducer().getBinary(fieldInfo); + break; + case SORTED: + iterator = derivedSourceReaders.getDocValuesProducer().getSorted(fieldInfo); + break; + case SORTED_NUMERIC: + iterator = derivedSourceReaders.getDocValuesProducer().getSortedNumeric(fieldInfo); + break; + case SORTED_SET: + iterator = derivedSourceReaders.getDocValuesProducer().getSortedSet(fieldInfo); + break; + default: + throw new AssertionError(); + } + } + if (iterator != null) { + return iterator.advance(offset); + } + + Terms terms = derivedSourceReaders.getFieldsProducer().terms(FieldNamesFieldMapper.NAME); + TermsEnum fieldNameFieldsTerms = terms.iterator(); + BytesRef fieldToMatchRef = new BytesRef(fieldToMatch); + PostingsEnum postingsEnum = null; + while (fieldNameFieldsTerms.next() != null) { + BytesRef currentTerm = fieldNameFieldsTerms.term(); + if (currentTerm.bytesEquals(fieldToMatchRef)) { + postingsEnum = 
fieldNameFieldsTerms.postings(null); + break; + } + } + assert postingsEnum != null; + return postingsEnum.advance(offset); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java new file mode 100644 index 0000000000..6729c9d09d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Iterator over the children documents of a particular parent + */ +public class NestedPerFieldParentToDocIdIterator { + + private final FieldInfo childFieldInfo; + private final SegmentReadState segmentReadState; + private final DerivedSourceReaders derivedSourceReaders; + private final int parentDocId; + private final int previousParentDocId; + private final List children; + private int currentChild; + + /** + * + * @param childFieldInfo FieldInfo for the child field + * @param segmentReadState SegmentReadState for the segment + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @param parentDocId Parent docId of the parent + * @throws IOException if there is an error reading the parent docId + */ + public NestedPerFieldParentToDocIdIterator( + FieldInfo childFieldInfo, + 
SegmentReadState segmentReadState, + DerivedSourceReaders derivedSourceReaders, + int parentDocId + ) throws IOException { + this.childFieldInfo = childFieldInfo; + this.segmentReadState = segmentReadState; + this.derivedSourceReaders = derivedSourceReaders; + this.parentDocId = parentDocId; + this.previousParentDocId = previousParent(); + this.children = getChildren(); + this.currentChild = -1; + } + + /** + * For the given parent get its first child offset + * + * @return the first child offset. If there are no children, just return NO_MORE_DOCS + */ + public int firstChild() { + if (parentDocId - previousParentDocId == 1) { + return NO_MORE_DOCS; + } + return previousParentDocId + 1; + } + + /** + * Get the number of children for this parent. + * + * @return number of children for this parent + */ + public int numChildren() { + return children.size(); + } + + /** + * Get the next child for this parent + * + * @return the next child docId. If this has not been set, return -1. If there are no more children, return + * NO_MORE_DOCS + */ + public int nextChild() { + currentChild++; + if (currentChild >= children.size()) { + return NO_MORE_DOCS; + } + return children.get(currentChild); + } + + /** + * Get the current child for this parent + * + * @return the current child docId. If this has not been set, return -1 + */ + public int childId() { + return children.get(currentChild); + } + + /** + * For parentDocId of this class, find the one just before it to be used for matching children. + * + * @return the parent docId just before the parentDocId. -1 if none exist + * @throws IOException if there is an error reading the parent docId + */ + private int previousParent() throws IOException { + // TODO: In the future this needs to be generalized to handle multiple levels of nesting + // For now, for non-nested docs, the primary_term field can be used to identify root level docs. 
For reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/SeqNoPrimaryTermPhase.java#L72 + // https://github.com/opensearch-project/OpenSearch/blob/3032bef54d502836789ea438f464ae0b1ba978b2/server/src/main/java/org/opensearch/index/mapper/SeqNoFieldMapper.java#L206-L230 + // We use it here to identify the previous parent to the current parent to get a range on the children documents + FieldInfo seqTermsFieldInfo = segmentReadState.fieldInfos.fieldInfo("_primary_term"); + NumericDocValues numericDocValues = derivedSourceReaders.getDocValuesProducer().getNumeric(seqTermsFieldInfo); + int previousParentDocId = -1; + while (numericDocValues.nextDoc() != NO_MORE_DOCS) { + if (numericDocValues.docID() >= parentDocId) { + break; + } + previousParentDocId = numericDocValues.docID(); + } + return previousParentDocId; + } + + /** + * Get all the children that match the parent path for the _nested_field + * + * @return list of children that match the parent path + * @throws IOException if there is an error reading the children + */ + private List getChildren() throws IOException { + // First, we need to get the correct PostingsEnum for the key as _nested_path and the value the actual parent + // path. + String childField = childFieldInfo.name; + String parentField = ParentChildHelper.getParentField(childField); + + Terms terms = derivedSourceReaders.getFieldsProducer().terms("_nested_path"); + TermsEnum nestedFieldsTerms = terms.iterator(); + BytesRef childPathRef = new BytesRef(parentField); + PostingsEnum postingsEnum = null; + while (nestedFieldsTerms.next() != null) { + BytesRef currentTerm = nestedFieldsTerms.term(); + if (currentTerm.bytesEquals(childPathRef)) { + postingsEnum = nestedFieldsTerms.postings(null); + break; + } + } + + // Next, get all the children that match this parent path.
If none were found, return an empty list + if (postingsEnum == null) { + return Collections.emptyList(); + } + List children = new ArrayList<>(); + postingsEnum.advance(previousParentDocId + 1); + while (postingsEnum.docID() != NO_MORE_DOCS && postingsEnum.docID() < parentDocId) { + if (postingsEnum.freq() > 0) { + children.add(postingsEnum.docID()); + } + postingsEnum.nextDoc(); + } + + return children; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java index c755e45d27..f47fbc9d09 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -36,4 +36,16 @@ public static String getChildField(String field) { int lastDot = field.lastIndexOf('.'); return field.substring(lastDot + 1); } + + /** + * Construct a sibling field path. For instance, if the field is "parent.to.child" and the sibling is "sibling", this + * would return "parent.to.sibling". + * + * @param field nested field path + * @param sibling sibling field + * @return sibling field path + */ + public static String constructSiblingField(String field, String sibling) { + return getParentField(field) + "." 
+ sibling; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java index f7ab3d5273..c0a1e0da00 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java @@ -27,9 +27,7 @@ public static PerFieldDerivedVectorInjector create( ) { // Nested case if (ParentChildHelper.getParentField(fieldInfo.name) != null) { - throw new IllegalArgumentException( - String.format("Field %s is a nested field. Nested fields are not supported by the derived source codec.", fieldInfo.name) - ); + return new NestedPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders, segmentReadState); } // Non-nested case diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 2153d96f80..0fb5335d40 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -7,7 +7,6 @@ import com.google.common.primitives.Floats; import lombok.SneakyThrows; -import org.junit.Ignore; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; @@ -273,7 +272,6 @@ public void testMultiFlatFields() { assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); } - @Ignore @SneakyThrows public void testNestedSingleDocBasic() { // For basic tests, we will have 0-5 nested documents per document @@ -300,19 +298,18 @@ public void testNestedSingleDocBasic() { 0.1f ); refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + 
assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); refreshAllIndices(); assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); refreshAllIndices(); forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); refreshAllIndices(); } - @Ignore @SneakyThrows public void testNestedMultiDocBasic() { String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6).toLowerCase(Locale.ROOT)); // "test"); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 95f2a11a0c..aef5abc646 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1473,6 +1473,7 @@ public void bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( } builder.endArray(); builder.endObject(); + // log.info(builder.toString()); addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); } } From 927b6deb8fe6c9a1010872f86c0f5b9307c38087 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 27 Jan 2025 23:00:20 -0800 Subject: [PATCH 03/18] Add the include/exclude optimization Signed-off-by: John Mazanec --- CHANGELOG.md | 2 +- .../DerivedSourceStoredFieldsReader.java | 18 +++++++++- .../DerivedSourceStoredFieldVisitor.java | 1 - .../DerivedSourceVectorInjector.java | 35 +++++++++++++++++++ .../NestedPerFieldDerivedVectorInjector.java | 12 ++----- .../NestedPerFieldParentToDocIdIterator.java | 9 ----- 6 files changed, 
55 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2508a05ed..1385fe1fbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add a new build mode, `FAISS_OPT_LEVEL=avx512_spr`, which enables the use of advanced AVX-512 instructions introduced with Intel(R) Sapphire Rapids (#2404)[https://github.com/opensearch-project/k-NN/pull/2404] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] - Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345] - +- Add derived source feature for vector fields (#2449)[https://github.com/opensearch-project/k-NN/pull/2449] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index 233197ae6e..ef9eba126d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -9,6 +9,7 @@ import lombok.Setter; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.index.fieldvisitor.FieldsVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; @@ -25,7 +26,15 @@ public class 
DerivedSourceStoredFieldsReader extends StoredFieldsReader { @Override public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { - if (shouldInject) { + // If the visitor has explicitly indicated it does not need the fields, we should not inject them + boolean isVisitorNeedFields = true; + if (storedFieldVisitor instanceof FieldsVisitor) { + isVisitorNeedFields = derivedSourceVectorInjector.shouldInject( + ((FieldsVisitor) storedFieldVisitor).includes(), + ((FieldsVisitor) storedFieldVisitor).excludes() + ); + } + if (shouldInject && isVisitorNeedFields) { delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector)); return; } @@ -47,6 +56,13 @@ public void close() throws IOException { delegate.close(); } + /** + * For merging, we need to tell the derived source stored fields reader to skip injecting the source. Otherwise, + * on merge we will end up just writing the source to disk + * + * @param storedFieldsReader stored fields reader to wrap + * @return wrapped stored fields reader + */ public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) { if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) { StoredFieldsReader storedFieldsReaderClone = storedFieldsReader.clone(); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java index 41a01c15c7..9610eff683 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java @@ -25,7 +25,6 @@ public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor { @Override public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { - // TODO: Add skip condition here if the 
delegate specifies which fields are not required for source if (fieldInfo.name.equals(SourceFieldMapper.NAME)) { delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value)); return; diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index 520b0bbf56..c7789a1cd9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -20,8 +20,10 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; /** * This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative @@ -31,6 +33,7 @@ public class DerivedSourceVectorInjector { private final List perFieldDerivedVectorInjectors; + private final Set fieldNames; /** * Constructor for DerivedSourceVectorInjector. 
@@ -46,10 +49,12 @@ public DerivedSourceVectorInjector( ) throws IOException { DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); this.perFieldDerivedVectorInjectors = new ArrayList<>(); + this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { this.perFieldDerivedVectorInjectors.add( PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState) ); + this.fieldNames.add(fieldInfo.name); } } @@ -84,4 +89,34 @@ public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOExcept builder.close(); return BytesReference.toBytes(BytesReference.bytes(builder)); } + + /** + * Whether or not to inject vectors based on what fields are explicitly required + * + * @param includes List of fields that are required to be injected + * @param excludes List of fields that are not required to be injected + * @return true if vectors should be injected, false otherwise + */ + public boolean shouldInject(String[] includes, String[] excludes) { + // If any of the vector fields are explicitly required we should inject + if (includes != null) { + for (String includedField : includes) { + if (fieldNames.contains(includedField)) { + return true; + } + } + } + + // If all of the vector fields are explicitly excluded we should not inject + if (excludes != null) { + int excludedVectorFieldCount = 0; + for (String excludedField : excludes) { + if (fieldNames.contains(excludedField)) { + excludedVectorFieldCount++; + } + } + return excludedVectorFieldCount >= fieldNames.size(); + } + return true; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index f3c72a27c0..6dce3f12ae 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.derivedsource; +import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.PostingsEnum; @@ -26,22 +27,13 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @Log4j2 +@AllArgsConstructor public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { private final FieldInfo childFieldInfo; private final DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; - public NestedPerFieldDerivedVectorInjector( - FieldInfo childFieldInfo, - DerivedSourceReaders derivedSourceReaders, - SegmentReadState segmentReadState - ) { - this.childFieldInfo = childFieldInfo; - this.derivedSourceReaders = derivedSourceReaders; - this.segmentReadState = segmentReadState; - } - @Override public void inject(Integer parentDocId, Map sourceAsMap) throws IOException { // Setup the iterator. Return if not-relevant diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java index 6729c9d09d..2d1dffa7fc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -68,15 +68,6 @@ public int firstChild() { return previousParentDocId + 1; } - /** - * Get the number of children for this parent. 
- * - * @return number of children for this parent - */ - public int numChildren() { - return children.size(); - } - /** * Get the next child for this parent * From 58164bb4299a89a8cc8e6785f8d544c2ab684162 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 13:31:02 -0800 Subject: [PATCH 04/18] Fix bugs and improve tests Signed-off-by: John Mazanec --- .../DerivedSourceVectorInjector.java | 8 +- .../opensearch/knn/integ/DerivedSourceIT.java | 914 ++++++++++++------ .../org/opensearch/knn/KNNRestTestCase.java | 57 +- 3 files changed, 679 insertions(+), 300 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index c7789a1cd9..7218735ec0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -11,6 +11,7 @@ import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -99,7 +100,7 @@ public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOExcept */ public boolean shouldInject(String[] includes, String[] excludes) { // If any of the vector fields are explicitly required we should inject - if (includes != null) { + if (includes != null && includes != Strings.EMPTY_ARRAY) { for (String includedField : includes) { if (fieldNames.contains(includedField)) { return true; @@ -108,14 +109,15 @@ public boolean shouldInject(String[] includes, String[] excludes) { } // If all of the vector fields are explicitly excluded we should not inject - if 
(excludes != null) { + if (excludes != null && excludes != Strings.EMPTY_ARRAY) { int excludedVectorFieldCount = 0; for (String excludedField : excludes) { if (fieldNames.contains(excludedField)) { excludedVectorFieldCount++; } } - return excludedVectorFieldCount >= fieldNames.size(); + // Inject if we haven't excluded all of the fields + return excludedVectorFieldCount < fieldNames.size(); } return true; } diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 0fb5335d40..424c05c842 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -6,13 +6,18 @@ package org.opensearch.knn.integ; import com.google.common.primitives.Floats; +import lombok.Builder; +import lombok.Data; import lombok.SneakyThrows; +import org.opensearch.common.CheckedConsumer; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.KNNSettings; +import java.io.IOException; +import java.util.List; import java.util.Locale; import java.util.Map; @@ -23,13 +28,6 @@ /** * Integration tests for derived source feature for vector fields. Currently, with derived source, there are * a few gaps in functionality. - * //TODO: Dimensions: - * // 1. Data type - * // 2. Dimension - * // 3. Nested level - * // 4. Vectors per field - * // 5. Other fields - * // 6.
Minimum number of values */ public class DerivedSourceIT extends KNNRestTestCase { @@ -51,134 +49,108 @@ public class DerivedSourceIT extends KNNRestTestCase { .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, false) .build(); + /** + * Testing flat, single field base case with index configuration: + * { + * "settings": { + * "index.knn" true, + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + * Comparing to the baseline: + * { + * "settings": { + * "index.knn" true, + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + */ @SneakyThrows public void testFlatBaseCase() { - String indexNameDerivedSourceEnabled = ("enabled-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String indexNameDerivedSourceDisabled = ("disabled-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - prepareFlatIndex(indexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - prepareFlatIndex(indexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, indexNameDerivedSourceDisabled); - forceMergeKnnIndex(indexNameDerivedSourceEnabled, 10); - forceMergeKnnIndex(indexNameDerivedSourceDisabled, 10); - refreshAllIndices(); - assertIndexBigger(indexNameDerivedSourceDisabled, indexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, indexNameDerivedSourceDisabled); - refreshAllIndices(); - forceMergeKnnIndex(indexNameDerivedSourceEnabled, 1); - forceMergeKnnIndex(indexNameDerivedSourceDisabled, 1); - refreshAllIndices(); - assertIndexBigger(indexNameDerivedSourceDisabled, indexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, indexNameDerivedSourceEnabled, indexNameDerivedSourceDisabled); - } - - @SneakyThrows - public 
void testFlatReindex() { - String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String reindexFromEnabledToEnabledIndexName = ("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String reindexFromEnabledToDisabledIndexName = ("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String reindexFromDisabledToEnabledIndexName = ("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String reindexFromDisabledToDisabledIndexName = ("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - - prepareFlatIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - prepareFlatIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - createKnnIndex(reindexFromEnabledToEnabledIndexName, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - createKnnIndex(reindexFromEnabledToDisabledIndexName, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - createKnnIndex(reindexFromDisabledToEnabledIndexName, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - createKnnIndex(reindexFromDisabledToDisabledIndexName, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - - refreshAllIndices(); - reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToEnabledIndexName); - reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToDisabledIndexName); - reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); - reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); - refreshAllIndices(); - - 
assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); - assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); - assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); - assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); - - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromEnabledToDisabledIndexName); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); - } - - @SneakyThrows - public void testFlatDeletesAndUpdates() { - String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - prepareFlatIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - prepareFlatIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNonNestedMappings(TEST_DIMENSION)); - - int docWithVectorUpdate = DOCS - 4; - int docWithVectorRemoval = 1; - int docWithVectorUpdateFromAPI = 2; - int docWithUpdateByQuery = 7; - int docToDelete = 8; - int docToDeleteByQuery = 11; - - float[] updateVector = randomFloatVector(TEST_DIMENSION); - updateKnnDoc( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithVectorUpdate), - FIELD_NAME, - Floats.asList(updateVector).toArray() - ); - updateKnnDoc( - originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithVectorUpdate), - FIELD_NAME, - Floats.asList(updateVector).toArray() - ); - refreshAllIndices(); - 
assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - - setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); - setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - - updateKnnDocWithUpdateAPI( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithVectorUpdateFromAPI), - FIELD_NAME, - Floats.asList(updateVector).toArray() - ); - updateKnnDocWithUpdateAPI( - originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithVectorUpdateFromAPI), - FIELD_NAME, - Floats.asList(updateVector).toArray() - ); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + 
IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() - updateKnnDocByQuery( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithUpdateByQuery), - FIELD_NAME, - Floats.asList(updateVector).toArray() - ); - updateKnnDocByQuery( - originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithUpdateByQuery), - FIELD_NAME, - Floats.asList(updateVector).toArray() ); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); 
- - deleteKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDelete)); - deleteKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDelete)); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - - deleteKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDeleteByQuery)); - deleteKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDeleteByQuery)); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + testDerivedSourceE2E(indexConfigContexts); } @SneakyThrows @@ -199,194 +171,588 @@ public void testMultiFlatFields() { .endObject() .endObject() .endObject(); + String multiFieldMapping = builder.toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + 
context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() - String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - - createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, builder.toString()); - 
createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, builder.toString()); - bulkIngestRandomVectorsWithSkipsAndMultFields( - originalIndexNameDerivedSourceEnabled, - FIELD_NAME + "1", - FIELD_NAME + "2", - "text", - DOCS, - TEST_DIMENSION, - 0.1f ); - bulkIngestRandomVectorsWithSkipsAndMultFields( - originalIndexNameDerivedSourceDisabled, - FIELD_NAME + "1", - FIELD_NAME + "2", - "text", - DOCS, - TEST_DIMENSION, - 0.1f + testDerivedSourceE2E(indexConfigContexts); + } + + public void testNestedSingleDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + @SneakyThrows + public void testNestedMultiDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + ); + testDerivedSourceE2E(indexConfigContexts); + } + + // TODO Test configurations + // 1. Baseline flat + // 2. Multi-field flat + // 3. Nested single doc + // 4. Nested multi-docs + // 5. Nested multi-fields multi docs + // 6. All types of fields + // 7. FLS index + // 8. Object fields + + // We need to write a single method that will run through all the different possible combinations and + // abstract when necessary. 
+ @SneakyThrows + private void testDerivedSourceE2E(List indexConfigContexts) { + // Make sure there are 6 + assertEquals(6, indexConfigContexts.size()); + + // Prepare the indices by creating them and ingesting data into them + prepareOriginalIndices(indexConfigContexts); + + // Merging + testMerging(indexConfigContexts); + + // Update + // TODO: Skipping nested for now + if (indexConfigContexts.get(0).isNested == false) { + testUpdate(indexConfigContexts); + } + + // Delete + testDelete(indexConfigContexts); + + // Search + testSearch(indexConfigContexts); + + // Reindex + testReindex(indexConfigContexts); + } + + @SneakyThrows + private void prepareOriginalIndices(List indexConfigContexts) { + assertEquals(6, indexConfigContexts.size()); + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + createKnnIndex(derivedSourceEnabledContext.indexName, derivedSourceEnabledContext.settings, derivedSourceEnabledContext.mapping); + createKnnIndex(derivedSourceDisabledContext.indexName, derivedSourceDisabledContext.settings, derivedSourceDisabledContext.mapping); + derivedSourceEnabledContext.indexIngestor.accept(derivedSourceEnabledContext); + derivedSourceDisabledContext.indexIngestor.accept(derivedSourceDisabledContext); refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + derivedSourceDisabledContext.indexName, + derivedSourceEnabledContext.indexName + ); + } + + @SneakyThrows + private void testMerging(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = 
derivedSourceDisabledContext.indexName; forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); refreshAllIndices(); assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); refreshAllIndices(); forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); refreshAllIndices(); assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceEnabled, originalIndexNameDerivedSourceDisabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + @SneakyThrows + private void testUpdate(List indexConfigContexts) { + // Random variables int docWithVectorUpdate = DOCS - 4; + int docWithVectorRemoval = 1; + int docWithVectorUpdateFromAPI = 2; int docWithUpdateByQuery = 7; - float[] updateVector = randomFloatVector(TEST_DIMENSION); - updateKnnDoc( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithVectorUpdate), - FIELD_NAME + "1", - Floats.asList(updateVector).toArray() - ); - updateKnnDoc( + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + float[] updateVector = randomFloatVector(derivedSourceDisabledContext.dimension); + + // Update via POST //_doc/ + for (String fieldName : 
derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithVectorUpdate), - FIELD_NAME + "1", - Floats.asList(updateVector).toArray() + originalIndexNameDerivedSourceEnabled ); - refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - updateKnnDocByQuery( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithUpdateByQuery), - FIELD_NAME + "2", - Floats.asList(updateVector).toArray() + setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); + setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled ); - updateKnnDocByQuery( + + // Use update API + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, 
originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithUpdateByQuery), - FIELD_NAME + "3", - Floats.asList(updateVector).toArray() + originalIndexNameDerivedSourceEnabled ); + + // Update by query + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); } @SneakyThrows - public void testNestedSingleDocBasic() { - // For basic tests, we will have 0-5 nested documents per document - String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT); - - createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); - createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); - - bulkIngestRandomVectorsWithSkipsAndNested( - originalIndexNameDerivedSourceEnabled, - NESTED_NAME + "." + FIELD_NAME, - NESTED_NAME + "." 
+ "text", - DOCS, - TEST_DIMENSION, - 0.1f - ); - bulkIngestRandomVectorsWithSkipsAndNested( + private void testDelete(List indexConfigContexts) { + int docToDelete = 8; + int docToDeleteByQuery = 11; + + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + + // Delete by API + deleteKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDelete)); + deleteKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDelete)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, originalIndexNameDerivedSourceDisabled, - NESTED_NAME + "." + FIELD_NAME, - NESTED_NAME + "." + "text", - DOCS, - TEST_DIMENSION, - 0.1f + originalIndexNameDerivedSourceEnabled ); + + // Delete by query + deleteKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDeleteByQuery)); + deleteKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDeleteByQuery)); refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); - forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); - refreshAllIndices(); - assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - refreshAllIndices(); - forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); - forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); - refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + 
originalIndexNameDerivedSourceEnabled + ); } @SneakyThrows - public void testNestedMultiDocBasic() { - String originalIndexNameDerivedSourceEnabled = ("original-enable-" + randomAlphaOfLength(6).toLowerCase(Locale.ROOT)); // "test"); - // /*randomAlphaOfLength(4)).toLowerCase(Locale.ROOT)*/; - String originalIndexNameDerivedSourceDisabled = ("original-disable-" + randomAlphaOfLength(6).toLowerCase(Locale.ROOT)); // + ; - // //"test"); - // /*randomAlphaOfLength(4)).toLowerCase(Locale.ROOT)*/; - - createKnnIndex(originalIndexNameDerivedSourceEnabled, DERIVED_ENABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); - createKnnIndex(originalIndexNameDerivedSourceDisabled, DERIVED_DISABLED_SETTINGS, createVectorNestedMappings(TEST_DIMENSION)); - - bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( - originalIndexNameDerivedSourceEnabled, - NESTED_NAME + "." + FIELD_NAME, - NESTED_NAME + "." + "text", - DOCS, - TEST_DIMENSION, - 0.1f, - 5 + private void testSearch(List indexConfigContexts) { + // TODO + + } + + @SneakyThrows + private void testReindex(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + IndexConfigContext reindexFromEnabledToEnabledContext = indexConfigContexts.get(2); + IndexConfigContext reindexFromEnabledToDisabledContext = indexConfigContexts.get(3); + IndexConfigContext reindexFromDisabledToEnabledContext = indexConfigContexts.get(4); + IndexConfigContext reindexFromDisabledToDisabledContext = indexConfigContexts.get(5); + + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + String reindexFromEnabledToEnabledIndexName = reindexFromEnabledToEnabledContext.indexName; + String reindexFromEnabledToDisabledIndexName = reindexFromEnabledToDisabledContext.indexName; + String 
reindexFromDisabledToEnabledIndexName = reindexFromDisabledToEnabledContext.indexName; + String reindexFromDisabledToDisabledIndexName = reindexFromDisabledToDisabledContext.indexName; + + createKnnIndex( + reindexFromEnabledToEnabledIndexName, + reindexFromEnabledToEnabledContext.getSettings(), + reindexFromEnabledToEnabledContext.getMapping() ); - bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( - originalIndexNameDerivedSourceDisabled, - NESTED_NAME + "." + FIELD_NAME, - NESTED_NAME + "." + "text", - DOCS, - TEST_DIMENSION, - 0.1f, - 5 + createKnnIndex( + reindexFromEnabledToDisabledIndexName, + reindexFromEnabledToDisabledContext.getSettings(), + reindexFromEnabledToDisabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToEnabledIndexName, + reindexFromDisabledToEnabledContext.getSettings(), + reindexFromDisabledToEnabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToDisabledIndexName, + reindexFromDisabledToDisabledContext.getSettings(), + reindexFromDisabledToDisabledContext.getMapping() ); refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); - forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); - refreshAllIndices(); - assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToDisabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); + + // Need to forcemerge before comparison refreshAllIndices(); 
forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); refreshAllIndices(); - assertDocsMatch(DOCS, originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - } + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); - // public void testNestedReindex() { - // - // } - // - // public void testNestedUpdateAndDelete() { - // - // } - // - // public void testMultiNestedFields() { - // // TODO - // } - // - // public void testMixedNestedAndFlatFields() { - // // TODO - // } - // - // public void testFLSSupport() { - // // TODO: Security only - need to figure out how to configure this one better - // } - // - // public void testNullSet() { - // // TODO: we know this breaks - // } + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToDisabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToDisabledIndexName + ); + } - @SneakyThrows - private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { - assertTrue(indexSizeInBytes(expectedSmallerIndex) < 
indexSizeInBytes(expectedBiggerIndex)); + @Builder + @Data + private static class IndexConfigContext { + String indexName; + List vectorFieldNames; + int dimension; + Settings settings; + String mapping; + boolean isNested; + int docCount; + CheckedConsumer indexIngestor; } @SneakyThrows - private void prepareFlatIndex(String indexName, Settings settings, String mapping) { - createKnnIndex(indexName, settings, mapping); - bulkIngestRandomVectorsWithSkips(indexName, FIELD_NAME, DOCS, TEST_DIMENSION, 0.1f); - refreshAllIndices(); + private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { + int expectedSmaller = indexSizeInBytes(expectedSmallerIndex); + int expectedBigger = indexSizeInBytes(expectedBiggerIndex); + assertTrue( + "Expected smaller index " + expectedSmaller + " was bigger than the expected bigger index:" + expectedBigger, + expectedSmaller < expectedBigger + ); } private void assertDocsMatch(int docCount, String index1, String index2) { diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index aef5abc646..b134802910 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -786,8 +786,14 @@ protected void addDocWithBinaryField(String index, String docId, String fieldNam */ protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + String parent = ParentChildHelper.getParentField(fieldName); + if (parent != null) { + builder.startObject(parent).field(fieldName, vector).endObject(); + } else { + builder.field(fieldName, vector); 
+ } + builder.endObject(); request.setJsonEntity(builder.toString()); @@ -800,16 +806,15 @@ protected void updateKnnDoc(String index, String docId, String fieldName, Object */ protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException { Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true"); - - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("doc") - .field(fieldName, vector) - .endObject() - .endObject(); - + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("doc"); + String parent = ParentChildHelper.getParentField(fieldName); + if (parent != null) { + builder.startObject(parent).field(fieldName, vector).endObject(); + } else { + builder.field(fieldName, vector); + } + builder.endObject().endObject(); request.setJsonEntity(builder.toString()); - Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } @@ -1408,18 +1413,24 @@ public void bulkIngestRandomVectorsWithSkipsAndMultFields( for (int i = 0; i < numVectors; i++) { float[] vector1 = vectors1[i]; float[] vector2 = vectors2[i]; - if (random.nextFloat() > skipProb) { - addKnnDoc( - indexName, - String.valueOf(i + 1), - XContentFactory.jsonBuilder() - .startObject() - .field(fieldName1, vector1) - .field(fieldName2, vector2) - .field(fieldName3, "test-test") - .endObject() - .toString() - ); + + boolean includeFieldOne = random.nextFloat() > skipProb; + boolean includeFieldTwo = random.nextFloat() > skipProb; + boolean includeFieldThree = random.nextFloat() > skipProb; + + if (includeFieldOne || includeFieldTwo || includeFieldThree) { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + if (includeFieldOne) { + xContentBuilder.field(fieldName1, vector1); + } + if (includeFieldTwo) { + 
xContentBuilder.field(fieldName2, vector2); + } + if (includeFieldThree) { + xContentBuilder.field(fieldName3, "test-test"); + } + xContentBuilder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), xContentBuilder.toString()); } else { addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); } From 9c8394e715d9647a3521cd0f70d94cf29476120b Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 13:35:04 -0800 Subject: [PATCH 05/18] Remove old test Signed-off-by: John Mazanec --- .../DerivedVectorInjectionConsumerTests.java | 48 ------------------- 1 file changed, 48 deletions(-) delete mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java deleted file mode 100644 index 1db2f22bb0..0000000000 --- a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedVectorInjectionConsumerTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.KNN9120Codec; - -import org.opensearch.knn.KNNTestCase; - -public class DerivedVectorInjectionConsumerTests extends KNNTestCase { - // - // @SneakyThrows - // public void testVectorInjection() { - // FloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - // List.of(new float[] { 1.0f, 2.0f }, new float[] { 2.0f, 3.0f }, new float[] { 3.0f, 4.0f }, new float[] { 4.0f, 5.0f }) - // ); - // final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); - // - // final XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - // builder.field("test_text", "text-field"); - // builder.endObject(); - // - // 
BytesReference bytesReference = BytesReference.bytes(builder); - // toMap(bytesReference); - // - // DerivedVectorInjectionConsumer consumer = new DerivedVectorInjectionConsumer(Map.of("test_vector", () -> knnVectorValues)); - // logger.info(bytesReference.length()); - // byte[] modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); - // BytesReference modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); - // toMap(modifiedBytesReference); - // - // modifiedBytes = consumer.apply(1, bytesReference.toBytesRef().bytes); - // modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); - // toMap(modifiedBytesReference); - // - // modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); - // modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); - // toMap(modifiedBytesReference); - // - // fail("On purpose"); - // } - // - // private void toMap(BytesReference source) { - // Tuple> mapTuple = XContentHelper.convertToMap(source, true, MediaTypeRegistry.JSON); - // logger.info(mapTuple.v2().toString()); - // } - -} From c8ff87869205596b7707aeb58b7818fbfe76e375 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 13:36:13 -0800 Subject: [PATCH 06/18] Cleanup Signed-off-by: John Mazanec --- .../java/org/opensearch/knn/ODFERestTestCase.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java index 4fc4cc8be0..efdde63c58 100644 --- a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java @@ -86,10 +86,6 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE return builder.build(); } - protected boolean preserveClusterUponCompletion() { - return false; - } - protected static void 
configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { // Similar to client configuration with OpenSearch: // https://github.com/opensearch-project/OpenSearch/blob/2.11.1/test/framework/src/main/java/org/opensearch/test/rest/OpenSearchRestTestCase.java#L841-L863 From 85a4c09aefecb755df091a0897261660a24ce8a6 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 13:37:17 -0800 Subject: [PATCH 07/18] Move derived to default as false Signed-off-by: John Mazanec --- src/main/java/org/opensearch/knn/index/KNNSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index b2f7b2f4c9..8d58b8c887 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -272,7 +272,7 @@ public class KNNSettings { public static final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( KNN_DERIVED_SOURCE_ENABLED, - true, + false, IndexScope, Setting.Property.Final ); From 939cec1216ad60c9d5998d538568ee365b1310cc Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 13:39:46 -0800 Subject: [PATCH 08/18] Fix up feature flag Signed-off-by: John Mazanec --- .../KNN9120Codec/DerivedSourceStoredFieldsFormat.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index cf242b0925..095222b54c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -46,6 +46,10 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI 
derivedVectorFields.add(fieldInfo); } } + // If no fields have it enabled, + if (derivedVectorFields.isEmpty()) { + return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); + } DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( derivedSourceReadersSupplier, new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext), @@ -67,7 +71,9 @@ public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentI vectorFieldTypes.add(fieldType.name()); } } - return new DerivedSourceStoredFieldsWriter(delegateWriter, vectorFieldTypes); + if (vectorFieldTypes.isEmpty() == false) { + return new DerivedSourceStoredFieldsWriter(delegateWriter, vectorFieldTypes); + } } return delegateWriter; } From 4bc2828b2bf707e4b6cac220a3955b43bd070993 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 14:02:42 -0800 Subject: [PATCH 09/18] Fix issue with fieldtype for lucene Signed-off-by: John Mazanec --- .../java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 2abbf182d1..49cd02d5b0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -129,7 +129,9 @@ private LuceneFieldMapper( } if (isDerivedSourceEnabled) { + this.fieldType = new FieldType(this.fieldType); this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + this.fieldType.freeze(); } KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() From 74387f9851b248eb2881f96fe23e7ac0945749aa Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 20:18:29 -0800 Subject: [PATCH 10/18] Address initial comments Signed-off-by: John Mazanec --- 
.../DerivedSourceStoredFieldsFormat.java | 20 ++--- .../DerivedSourceStoredFieldsReader.java | 86 +++++++++++++++---- .../DerivedSourceStoredFieldsWriter.java | 4 - .../derivedsource/DerivedSourceReaders.java | 11 ++- .../DerivedSourceVectorInjector.java | 14 ++- .../NestedPerFieldDerivedVectorInjector.java | 2 +- .../PerFieldDerivedVectorInjector.java | 2 +- .../RootPerFieldDerivedVectorInjector.java | 2 +- 8 files changed, 104 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index 095222b54c..55d8868dc1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -19,7 +19,6 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; -import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.io.IOException; @@ -40,24 +39,25 @@ public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { @Override public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) throws IOException { - List derivedVectorFields = new ArrayList<>(); + List derivedVectorFields = null; for (FieldInfo fieldInfo : fieldInfos) { if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) { + // Lazily initialize the list of fields + if (derivedVectorFields == null) { + derivedVectorFields = new ArrayList<>(); + } derivedVectorFields.add(fieldInfo); } } - // If no fields have it enabled, - if 
(derivedVectorFields.isEmpty()) { + // If no fields have it enabled, we can just short-circuit and return the delegate's fieldReader + if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } - DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( - derivedSourceReadersSupplier, - new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext), - derivedVectorFields - ); return new DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), - derivedSourceVectorInjector + derivedVectorFields, + derivedSourceReadersSupplier, + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index ef9eba126d..6c1ade140b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -5,24 +5,63 @@ package org.opensearch.knn.index.codec.KNN9120Codec; -import lombok.RequiredArgsConstructor; -import lombok.Setter; import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.StoredFieldVisitor; +import org.apache.lucene.util.IOUtils; import org.opensearch.index.fieldvisitor.FieldsVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; import java.io.IOException; +import java.util.List; -@RequiredArgsConstructor public class 
DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; - // Given docId and source, process source + private final List derivedVectorFields; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final SegmentReadState segmentReadState; + private final boolean shouldInject; + private final DerivedSourceVectorInjector derivedSourceVectorInjector; - @Setter - private boolean shouldInject = true; + /** + * + * @param delegate delegate StoredFieldsReader + * @param derivedVectorFields List of fields that are derived source fields + * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param segmentReadState SegmentReadState for the segment + * @throws IOException in case of I/O error + */ + public DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState + ) throws IOException { + this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + } + + private DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + boolean shouldInject + ) throws IOException { + this.delegate = delegate; + this.derivedVectorFields = derivedVectorFields; + this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.segmentReadState = segmentReadState; + this.shouldInject = shouldInject; + this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); + } + + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { + return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + } @Override public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { @@ -43,7 
+82,17 @@ public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IO @Override public StoredFieldsReader clone() { - return new DerivedSourceStoredFieldsReader(delegate.clone(), derivedSourceVectorInjector); + try { + return new DerivedSourceStoredFieldsReader( + delegate.clone(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + shouldInject + ); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override @@ -53,22 +102,27 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { - delegate.close(); + IOUtils.close(delegate, derivedSourceVectorInjector); } /** * For merging, we need to tell the derived source stored fields reader to skip injecting the source. Otherwise, * on merge we will end up just writing the source to disk * - * @param storedFieldsReader stored fields reader to wrap - * @return wrapped stored fields reader + * @return Merged instance that wont inject by default */ - public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) { - if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) { - StoredFieldsReader storedFieldsReaderClone = storedFieldsReader.clone(); - ((DerivedSourceStoredFieldsReader) storedFieldsReaderClone).setShouldInject(false); - return storedFieldsReaderClone; + @Override + public StoredFieldsReader getMergeInstance() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.getMergeInstance(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + false + ); + } catch (IOException e) { + throw new RuntimeException(e); } - return storedFieldsReader; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index 1b3c8b3b12..b01da60011 100644 --- 
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -65,10 +65,6 @@ public void writeField(FieldInfo info, DataInput value, int length) throws IOExc @Override public int merge(MergeState mergeState) throws IOException { - // We have to wrap these here to avoid storing the vectors during merge - for (int i = 0; i < mergeState.storedFieldsReaders.length; i++) { - mergeState.storedFieldsReaders[i] = DerivedSourceStoredFieldsReader.wrapForMerge(mergeState.storedFieldsReaders[i]); - } return delegate.merge(mergeState); } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 5bdcc5181d..1b3cdb3f85 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -10,14 +10,23 @@ import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.util.IOUtils; + +import java.io.Closeable; +import java.io.IOException; /** * Class holds the readers necessary to implement derived source. 
*/ @RequiredArgsConstructor @Getter -public class DerivedSourceReaders { +public class DerivedSourceReaders implements Closeable { private final KnnVectorsReader knnVectorsReader; private final DocValuesProducer docValuesProducer; private final FieldsProducer fieldsProducer; + + @Override + public void close() throws IOException { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index 7218735ec0..c59bd43792 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -8,6 +8,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.util.IOUtils; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentHelper; @@ -17,6 +18,7 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.Closeable; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -31,8 +33,9 @@ * format readers and information about the fields to inject vectors into the source. 
*/ @Log4j2 -public class DerivedSourceVectorInjector { +public class DerivedSourceVectorInjector implements Closeable { + private final DerivedSourceReaders derivedSourceReaders; private final List perFieldDerivedVectorInjectors; private final Set fieldNames; @@ -48,7 +51,7 @@ public DerivedSourceVectorInjector( SegmentReadState segmentReadState, List fieldsToInjectVector ) throws IOException { - DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); this.perFieldDerivedVectorInjectors = new ArrayList<>(); this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { @@ -67,7 +70,7 @@ public DerivedSourceVectorInjector( * @return byte array of the source with the vector fields added * @throws IOException if there is an issue reading from the formats */ - public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOException { + public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException { // Deserialize the source into a modifiable map Tuple> mapTuple = XContentHelper.convertToMap( BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), @@ -121,4 +124,9 @@ public boolean shouldInject(String[] includes, String[] excludes) { } return true; } + + @Override + public void close() throws IOException { + IOUtils.close(derivedSourceReaders); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index 6dce3f12ae..4f9f866e88 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -35,7 +35,7 @@ public class NestedPerFieldDerivedVectorInjector 
implements PerFieldDerivedVecto private final SegmentReadState segmentReadState; @Override - public void inject(Integer parentDocId, Map sourceAsMap) throws IOException { + public void inject(int parentDocId, Map sourceAsMap) throws IOException { // Setup the iterator. Return if not-relevant String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name); String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java index 2467c26108..b0bc5930c3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java @@ -21,5 +21,5 @@ public interface PerFieldDerivedVectorInjector { * @param sourceAsMap Source as map * @throws IOException if there is an issue reading from the formats */ - void inject(Integer docId, Map sourceAsMap) throws IOException; + void inject(int docId, Map sourceAsMap) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java index 4812dca3c6..b46744a51b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -37,7 +37,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade } @Override - public void inject(Integer docId, Map sourceAsMap) throws IOException { + public void inject(int docId, Map sourceAsMap) throws IOException { KNNVectorValues vectorValues = vectorValuesSupplier.get(); if (vectorValues.docId() == docId || 
vectorValues.advance(docId) == docId) { sourceAsMap.put(fieldInfo.name, vectorValues.getVector()); From d449151d108bad6775584d7e79a40f26766fac45 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 05:20:59 -0800 Subject: [PATCH 11/18] Add partial support for object mappings Signed-off-by: John Mazanec --- .../codec/KNN9120Codec/KNN9120Codec.java | 28 ++- .../derivedsource/DerivedSourceReaders.java | 7 +- .../DerivedSourceReadersSupplier.java | 5 +- .../NestedPerFieldDerivedVectorInjector.java | 149 +++++++++++--- .../NestedPerFieldParentToDocIdIterator.java | 3 + .../derivedsource/ParentChildHelper.java | 11 + .../vectorvalues/KNNVectorValuesFactory.java | 6 +- .../opensearch/knn/integ/DerivedSourceIT.java | 192 +++++++++++++++++- .../org/opensearch/knn/KNNRestTestCase.java | 85 ++++++++ 9 files changed, 440 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index e0d9b678b5..b8a5e6a121 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -67,11 +67,29 @@ public KnnVectorsFormat knnVectorsFormat() { @Override public StoredFieldsFormat storedFieldsFormat() { - DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier( - (segmentReadState) -> knnVectorsFormat().fieldsReader(segmentReadState), - (segmentReadState) -> docValuesFormat().fieldsProducer(segmentReadState), - (segmentReadState) -> postingsFormat().fieldsProducer(segmentReadState) - ); + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { + if (segmentReadState.fieldInfos.hasVectorValues()) { + return knnVectorsFormat().fieldsReader(segmentReadState); + } + return null; + }, (segmentReadState) -> { + if 
(segmentReadState.fieldInfos.hasDocValues()) { + return docValuesFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasPostings()) { + return postingsFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState -> { + if (segmentReadState.fieldInfos.hasNorms()) { + return normsFormat().normsProducer(segmentReadState); + } + return null; + })); return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 1b3cdb3f85..3f2d418fb3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -10,13 +10,15 @@ import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; import org.apache.lucene.util.IOUtils; import java.io.Closeable; import java.io.IOException; /** - * Class holds the readers necessary to implement derived source. + * Class holds the readers necessary to implement derived source. Important to note that if a segment does not have + * any of these fields, the values will be null. Caller needs to check if these are null before using. 
*/ @RequiredArgsConstructor @Getter @@ -24,9 +26,10 @@ public class DerivedSourceReaders implements Closeable { private final KnnVectorsReader knnVectorsReader; private final DocValuesProducer docValuesProducer; private final FieldsProducer fieldsProducer; + private final NormsProducer normsProducer; @Override public void close() throws IOException { - IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer); + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java index 8e46952c2f..2dafa3af94 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -9,6 +9,7 @@ import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; import org.apache.lucene.index.SegmentReadState; import java.io.IOException; @@ -23,6 +24,7 @@ public class DerivedSourceReadersSupplier { private final DerivedSourceReaderSupplier knnVectorsReaderSupplier; private final DerivedSourceReaderSupplier docValuesProducerSupplier; private final DerivedSourceReaderSupplier fieldsProducerSupplier; + private final DerivedSourceReaderSupplier normsProducer; /** * Get the readers for the segment @@ -35,7 +37,8 @@ public DerivedSourceReaders getReaders(SegmentReadState state) throws IOExceptio return new DerivedSourceReaders( knnVectorsReaderSupplier.apply(state), docValuesProducerSupplier.apply(state), - fieldsProducerSupplier.apply(state) + fieldsProducerSupplier.apply(state), + normsProducer.apply(state) ); } } diff --git 
a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index 4f9f866e88..c85aacd9c1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.SegmentReadState; @@ -36,7 +37,19 @@ public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVecto @Override public void inject(int parentDocId, Map sourceAsMap) throws IOException { - // Setup the iterator. Return if not-relevant + // If the parent has the field, then it is just an object field. + if (getLowestDocIdForField(childFieldInfo.name, parentDocId) == parentDocId) { + injectObject(parentDocId, sourceAsMap); + return; + } + + if (ParentChildHelper.splitPath(childFieldInfo.name).length > 2) { + // We do not support nested fields beyond one level + log.warn("Nested fields beyond one level are not supported. Field: {}", childFieldInfo.name); + return; + } + + // Setup the iterator. 
Return if no parent String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name); String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name); if (parentFieldName == null) { @@ -49,7 +62,7 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx parentDocId ); - // Initializes the parent field so that there is a map to put each of the children + // Initializes the parent field so that there is a list to put each of the children Object originalParentValue = sourceAsMap.get(parentFieldName); List> reconstructedSource; if (originalParentValue instanceof Map) { @@ -58,8 +71,9 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx reconstructedSource = (List>) originalParentValue; } - // Contains the positions of existing objects in the map. This is used to help figure out the best play to put back the vectors - List positions = mapObjectsToPositionInSource( + // Contains the docIds of existing objects in the map in order. This is used to help figure out the best place + // to put back the vectors + List positions = mapObjectsToPositionInNestedList( reconstructedSource, nestedPerFieldParentToDocIdIterator.firstChild(), parentDocId @@ -67,7 +81,7 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx // Finally, inject children for the document into the source. This code is non-trivial because filtering out // the vectors during write could mean that children docs disappear from the source. So, to properly put
So, to properly put - // everything back, we need to igure out where the existing fields in the original map to + // everything back, we need to figure out where the existing fields in the original map to KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( childFieldInfo, derivedSourceReaders.getDocValuesProducer(), @@ -83,10 +97,8 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx if (vectorValues.docId() != nestedPerFieldParentToDocIdIterator.childId()) { continue; } - int docId = vectorValues.docId(); - if (docId >= parentDocId) { - break; - } + + int docId = nestedPerFieldParentToDocIdIterator.childId(); boolean isInsert = true; int position = positions.size(); // by default we insert it at the end for (int i = offsetPositionsIndex; i < positions.size(); i++) { @@ -111,11 +123,40 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx sourceAsMap.put(parentFieldName, reconstructedSource); } - private List mapObjectsToPositionInSource(List> originals, int firstChild, int parent) throws IOException { + private void injectObject(int docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( + childFieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + if (vectorValues.docId() != docId && vectorValues.advance(docId) != docId) { + return; + } + String[] fields = ParentChildHelper.splitPath(childFieldInfo.name); + Map currentMap = sourceAsMap; + for (int i = 0; i < fields.length - 1; i++) { + String field = fields[i]; + currentMap = (Map) currentMap.computeIfAbsent(field, k -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectorValues.getVector()); + } + + /** + * Given a list of maps, map each map to a position in the nested list. This is used to help figure out where to put + * the vectors back in the source. 
+ * + * @param originals list of maps + * @param firstChild first child docId + * @param parent parent docId + * @return list of positions in the nested list + * @throws IOException if there is an issue reading from the formats + */ + private List mapObjectsToPositionInNestedList(List> originals, int firstChild, int parent) + throws IOException { List positions = new ArrayList<>(); int offset = firstChild; for (Map docWithFields : originals) { - int fieldMapping = docToOrdinal(docWithFields, offset, parent); + int fieldMapping = mapToDocId(docWithFields, offset, parent); assert fieldMapping != -1; positions.add(fieldMapping); offset = fieldMapping + 1; @@ -123,27 +164,69 @@ private List mapObjectsToPositionInSource(List> ori return positions; } - // Offset is first eligible object - private Integer docToOrdinal(Map doc, int offset, int parent) throws IOException { - String keyToCheck = doc.keySet().iterator().next(); - int position = getFieldsForDoc(keyToCheck, offset); - // Advancing past the parent means something went horribly wrong + /** + * Given a doc as a map and the offset it has to be, find the ordinal of the first field that is greater than the + * offset. + * + * @param doc doc to find the ordinal for + * @param offset offset to start searching from + * @return id of the first field that is greater than the offset + * @throws IOException if there is an issue reading from the formats + */ + private int mapToDocId(Map doc, int offset, int parent) throws IOException { + // For all the fields, we look for the first doc that matches any of the fields. 
+ int position = NO_MORE_DOCS; + for (String key : doc.keySet()) { + position = getLowestDocIdForField(ParentChildHelper.constructSiblingField(childFieldInfo.name, key), offset); + if (position < parent) { + break; + } + } + + // Advancing past the parent means something went wrong assert position < parent; return position; } - private int getFieldsForDoc(String fieldToMatch, int offset) throws IOException { - // TODO: Fix this up to follow - // https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java#L170-L218. - // In a perfect world, it would try everything and fall through to the field exists stuff - FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo( - ParentChildHelper.constructSiblingField(childFieldInfo.name, fieldToMatch) - ); + /** + * Get the lowest docId for a field that is at or after the offset. + * + * @param fieldToMatch field to find the lowest docId for + * @param offset offset to start searching from + * @return lowest docId for the field that is at or after the offset. Returns {@link DocIdSetIterator#NO_MORE_DOCS} if doc cannot be found + * @throws IOException if there is an issue reading from the formats + */ + private int getLowestDocIdForField(String fieldToMatch, int offset) throws IOException { + // This method implementation is inspired by the FieldExistsQuery in Lucene and the FieldNamesFieldMapper in + // OpenSearch. We first mimic the logic in the FieldExistsQuery in order to identify the docId of the nested + // doc. If that fails, we rely on the postings of the field names field (FieldNamesFieldMapper.NAME). + // References: + // 1. https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java#L170-L218. + // 2.
+ // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/mapper/FieldMapper.java#L316-L324 + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(fieldToMatch); + + if (fieldInfo == null) { + return NO_MORE_DOCS; + } + DocIdSetIterator iterator = null; - if (fieldInfo != null) { - switch (fieldInfo.getDocValuesType()) { - case NONE: + if (fieldInfo.hasNorms() && derivedSourceReaders.getNormsProducer() != null) { // the field indexes norms + iterator = derivedSourceReaders.getNormsProducer().getNorms(fieldInfo); + } else if (fieldInfo.getVectorDimension() != 0 && derivedSourceReaders.getKnnVectorsReader() != null) { // the field indexes vectors + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + iterator = derivedSourceReaders.getKnnVectorsReader().getFloatVectorValues(fieldInfo.name); + break; + case BYTE: + iterator = derivedSourceReaders.getKnnVectorsReader().getByteVectorValues(fieldInfo.name); break; + } + } else if (fieldInfo.getDocValuesType() != DocValuesType.NONE && derivedSourceReaders.getDocValuesProducer() != null) { // the field + // indexes + // doc + // values + switch (fieldInfo.getDocValuesType()) { case NUMERIC: iterator = derivedSourceReaders.getDocValuesProducer().getNumeric(fieldInfo); break; @@ -159,6 +242,7 @@ private int getFieldsForDoc(String fieldToMatch, int offset) throws IOException case SORTED_SET: iterator = derivedSourceReaders.getDocValuesProducer().getSortedSet(fieldInfo); break; + case NONE: default: throw new AssertionError(); } @@ -167,9 +251,16 @@ private int getFieldsForDoc(String fieldToMatch, int offset) throws IOException return iterator.advance(offset); } + // Check the field names field type for matches + if (derivedSourceReaders.getFieldsProducer() == null) { + return NO_MORE_DOCS; + } Terms terms = derivedSourceReaders.getFieldsProducer().terms(FieldNamesFieldMapper.NAME); + if (terms == null) { + return NO_MORE_DOCS; + } TermsEnum fieldNameFieldsTerms = 
terms.iterator(); - BytesRef fieldToMatchRef = new BytesRef(fieldToMatch); + BytesRef fieldToMatchRef = new BytesRef(fieldInfo.name); PostingsEnum postingsEnum = null; while (fieldNameFieldsTerms.next() != null) { BytesRef currentTerm = fieldNameFieldsTerms.term(); @@ -178,7 +269,9 @@ private int getFieldsForDoc(String fieldToMatch, int offset) throws IOException break; } } - assert postingsEnum != null; + if (postingsEnum == null) { + return NO_MORE_DOCS; + } return postingsEnum.advance(offset); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java index 2d1dffa7fc..d6d4e50621 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -128,6 +128,9 @@ private List getChildren() throws IOException { String parentField = ParentChildHelper.getParentField(childField); Terms terms = derivedSourceReaders.getFieldsProducer().terms("_nested_path"); + if (terms == null) { + return Collections.emptyList(); + } TermsEnum nestedFieldsTerms = terms.iterator(); BytesRef childPathRef = new BytesRef(parentField); PostingsEnum postingsEnum = null; diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java index f47fbc9d09..534cf93d7d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -48,4 +48,15 @@ public static String getChildField(String field) { public static String constructSiblingField(String field, String sibling) { return getParentField(field) + "." 
+ sibling; } + + /** + * Split a nested field path into an array of strings. For instance, if the field is "parent.to.child", this would + * return ["parent", "to", "child"]. + * + * @param field nested field path + * @return array of strings representing the nested field path + */ + public static String[] splitPath(String field) { + return field.split("\\."); + } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 699d62843b..9ae7f3842a 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -88,7 +88,7 @@ public static KNNVectorValues getVectorValues( final KnnVectorsReader knnVectorsReader ) throws IOException { final DocIdSetIterator docIdSetIterator; - if (fieldInfo.hasVectorValues()) { + if (fieldInfo.hasVectorValues() && knnVectorsReader != null) { if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { docIdSetIterator = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { @@ -96,8 +96,10 @@ public static KNNVectorValues getVectorValues( } else { throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); } - } else { + } else if (docValuesProducer != null) { docIdSetIterator = docValuesProducer.getBinary(fieldInfo); + } else { + throw new IllegalArgumentException("Field does not have vector values and DocValues"); } final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 
424c05c842..0e1fb731fc 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -440,15 +440,191 @@ public void testNestedMultiDocBasic() { testDerivedSourceE2E(indexConfigContexts); } + /** + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testObjectFieldTypes() { + String PATH_1_NAME = "path_1"; + String PATH_2_NAME = "path_2"; + + String objectFieldTypeMapping = XContentFactory.jsonBuilder() + .startObject() // 1-open + .startObject(PROPERTIES_FIELD) // 2-open + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + + .startObject(PATH_1_NAME) + .startObject(PROPERTIES_FIELD) + + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(PATH_2_NAME) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "3") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." 
+ FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." 
+ FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + // TODO Test configurations - // 1. Baseline flat - // 2. Multi-field flat - // 3. Nested single doc - // 4. Nested multi-docs - // 5. Nested multi-fields multi docs - // 6. All types of fields - // 7. FLS index - // 8. Object fields + // 1. Object fields + // 2. FLS index // We need to write a single method that will run through all the different possible combinations and // abstact when necessary. 
diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index b134802910..bd4858bb5a 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1437,6 +1437,91 @@ public void bulkIngestRandomVectorsWithSkipsAndMultFields( } } + @SneakyThrows + public void bulkIngestRandomVectorsMultiFieldsWithSkips( + String indexName, + List vectorFields, + List textFields, + int numVectors, + int dimension, + float skipProb + ) { + List vectors = new ArrayList<>(); + int seed = 1; + for (String ignored : vectorFields) { + vectors.add(TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, seed++)); + } + + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + + List includeVectorFields = new ArrayList<>(); + for (String ignored : vectorFields) { + includeVectorFields.add(random.nextFloat() > skipProb); + } + List includeTextFields = new ArrayList<>(); + for (String ignored : textFields) { + includeTextFields.add(random.nextFloat() > skipProb); + } + + // If all are skipped, just add a random field + if (includeVectorFields.stream().allMatch((t) -> !t) && includeTextFields.stream().allMatch((t) -> !t)) { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } else { + Map source = new HashMap<>(); + for (int j = 0; j < includeVectorFields.size(); j++) { + if (includeVectorFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(vectorFields.get(j)); + Map currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + log.info("Value: " + value); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectors.get(j)[i]); + } + } + for (int j = 0; j < includeTextFields.size(); j++) { + if 
(includeTextFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(textFields.get(j)); + log.info("Fields: " + Arrays.toString(fields)); + Map currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + log.info("FUll path: " + textFields.get(j)); + log.info("Key: " + field); + log.info("Value: " + value); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], "test-test"); + } + } + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + mapToBuilder(builder, source); + builder.endObject(); + log.info(builder.toString()); + addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); + } + } + } + + @SneakyThrows + void mapToBuilder(XContentBuilder xContentBuilder, Map source) { + for (Map.Entry entry : source.entrySet()) { + if (entry.getValue() instanceof Map) { + xContentBuilder.startObject(entry.getKey()); + mapToBuilder(xContentBuilder, (Map) entry.getValue()); + xContentBuilder.endObject(); + } else { + xContentBuilder.field(entry.getKey(), entry.getValue()); + } + } + } + public void bulkIngestRandomVectorsWithSkipsAndNested( String indexName, String nestedFieldName, From c3483dfa805f7e31b422720c5eaec00337a2e4de Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 05:23:15 -0800 Subject: [PATCH 12/18] Ignore its for now Signed-off-by: John Mazanec --- src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 0e1fb731fc..4b569a41ad 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -9,6 +9,7 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; +import 
org.junit.Ignore; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; @@ -27,8 +28,9 @@ /** * Integration tests for derived source feature for vector fields. Currently, with derived source, there are - * a few gaps in functionality. + * a few gaps in functionality. Ignoring tests for now as feature is experimental. */ +@Ignore public class DerivedSourceIT extends KNNRestTestCase { private final static String NESTED_NAME = "test_nested"; From 399d281b00325d0328002bdfc1d098809b487ce8 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 08:53:16 -0800 Subject: [PATCH 13/18] Fix issues with object type Signed-off-by: John Mazanec --- .../NestedPerFieldDerivedVectorInjector.java | 13 +- .../NestedPerFieldParentToDocIdIterator.java | 12 + .../opensearch/knn/integ/DerivedSourceIT.java | 246 ++++++++++++++++-- .../org/opensearch/knn/KNNRestTestCase.java | 7 - 4 files changed, 237 insertions(+), 41 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index c85aacd9c1..016b4153aa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -38,17 +38,12 @@ public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVecto @Override public void inject(int parentDocId, Map sourceAsMap) throws IOException { // If the parent has the field, then it is just an object field. 
- if (getLowestDocIdForField(childFieldInfo.name, parentDocId) == parentDocId) { + int lowestDocIdForFieldWithParentAsOffset = getLowestDocIdForField(childFieldInfo.name, parentDocId); + if (lowestDocIdForFieldWithParentAsOffset == parentDocId) { injectObject(parentDocId, sourceAsMap); return; } - if (ParentChildHelper.splitPath(childFieldInfo.name).length > 2) { - // We do not support nested fields beyond one level - log.warn("Nested fields beyond one level are not supported. Field: {}", childFieldInfo.name); - return; - } - // Setup the iterator. Return if no parent String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name); String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name); @@ -62,6 +57,10 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx parentDocId ); + if (nestedPerFieldParentToDocIdIterator.numChildren() == 0) { + return; + } + // Initializes the parent field so that there is a list to put each of the children Object originalParentValue = sourceAsMap.get(parentFieldName); List> reconstructedSource; diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java index d6d4e50621..d2bc1a32fd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -91,6 +91,14 @@ public int childId() { return children.get(currentChild); } + /** + * + * @return the number of children for this parent + */ + public int numChildren() { + return children.size(); + } + /** * For parentDocId of this class, find the one just before it to be used for matching children. 
* @@ -122,6 +130,10 @@ private int previousParent() throws IOException { * @throws IOException if there is an error reading the children */ private List getChildren() throws IOException { + if (this.parentDocId - this.previousParentDocId <= 1) { + return Collections.emptyList(); + } + // First, we need to get the currect PostingsEnum for the key as _nested_path and the value the actual parent // path. String childField = childFieldInfo.name; diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 4b569a41ad..2607a9695f 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -9,7 +9,6 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; -import org.junit.Ignore; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; @@ -30,7 +29,7 @@ * Integration tests for derived source feature for vector fields. Currently, with derived source, there are * a few gaps in functionality. Ignoring tests for now as feature is experimental. */ -@Ignore +// @Ignore public class DerivedSourceIT extends KNNRestTestCase { private final static String NESTED_NAME = "test_nested"; @@ -52,10 +51,14 @@ public class DerivedSourceIT extends KNNRestTestCase { .build(); /** - * Testing flat, single field base case with index configuration: + * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for + * random documents to ensure it works robustly. To ensure correctness, we repeat same operations against an + * index without derived source enabled (baseline). 
+ * Test mapping: * { * "settings": { * "index.knn" true, + * "index.knn.derived_source.enabled": true * }, * "mappings":{ * "properties": { @@ -66,10 +69,11 @@ public class DerivedSourceIT extends KNNRestTestCase { * } * } * } - * Comparing to the baseline: + * Baseline mapping: * { * "settings": { * "index.knn" true, + * "index.knn.derived_source.enabled": false * }, * "mappings":{ * "properties": { @@ -155,6 +159,53 @@ public void testFlatBaseCase() { testDerivedSourceE2E(indexConfigContexts); } + /** + * Testing multiple flat fields. + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector2": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector2": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + */ @SneakyThrows public void testMultiFlatFields() { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -263,6 +314,55 @@ public void testMultiFlatFields() { testDerivedSourceE2E(indexConfigContexts); } + /** + * Testing single nested doc per parent doc.
+ * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ public void testNestedSingleDocBasic() { String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); List indexConfigContexts = List.of( @@ -351,6 +451,56 @@ public void testNestedSingleDocBasic() { testDerivedSourceE2E(indexConfigContexts); } + /** + * Testing multiple nested docs per parent doc. 
+ * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ @SneakyThrows public void testNestedMultiDocBasic() { String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); @@ -443,28 +593,71 @@ public void testNestedMultiDocBasic() { } /** + * Test object (non-nested field) + * Test * { - * "properties": { - * "vector_field_1" : { - * "type" : "knn_vector", - * "dimension" : 2 - * }, - * "path_1": { - * "properties" : { - * "vector_field_2" : { - * "type" : "knn_vector", - * "dimension" : 2 - * }, - * "path_2": { - * "properties" : { - * "vector_field_3" : { - * "type" : "knn_vector", - * "dimension" : 2 - * }, - * } - * } - * } - * } + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + * } + * } + * Baseline + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * { + * "properties": { + * 
"vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } * } * } */ @@ -625,8 +818,7 @@ public void testObjectFieldTypes() { } // TODO Test configurations - // 1. Object fields - // 2. FLS index + // 1. Ensure compatibility with FLS // We need to write a single method that will run through all the different possible combinations and // abstact when necessary. diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index bd4858bb5a..e95deb558d 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1477,7 +1477,6 @@ public void bulkIngestRandomVectorsMultiFieldsWithSkips( for (int k = 0; k < fields.length - 1; k++) { String field = fields[k]; Object value = currentMap.get(field); - log.info("Value: " + value); currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); } currentMap.put(fields[fields.length - 1], vectors.get(j)[i]); @@ -1486,14 +1485,10 @@ public void bulkIngestRandomVectorsMultiFieldsWithSkips( for (int j = 0; j < includeTextFields.size(); j++) { if (includeTextFields.get(j)) { String[] fields = ParentChildHelper.splitPath(textFields.get(j)); - log.info("Fields: " + Arrays.toString(fields)); Map currentMap = source; for (int k = 0; k < fields.length - 1; k++) { String field = fields[k]; Object value = currentMap.get(field); - log.info("FUll path: " + textFields.get(j)); - log.info("Key: " + field); - log.info("Value: " + value); currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); } currentMap.put(fields[fields.length - 1], "test-test"); @@ -1503,7 +1498,6 @@ 
public void bulkIngestRandomVectorsMultiFieldsWithSkips( XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); mapToBuilder(builder, source); builder.endObject(); - log.info(builder.toString()); addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); } } @@ -1569,7 +1563,6 @@ public void bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( } builder.endArray(); builder.endObject(); - // log.info(builder.toString()); addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); } } From 14d3ead20e18cf99ae3429fbd235ed959f16c34a Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 09:31:35 -0800 Subject: [PATCH 14/18] Remove conditional clone vector Signed-off-by: John Mazanec --- .../derivedsource/NestedPerFieldDerivedVectorInjector.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index 016b4153aa..73dc97483d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx reconstructedSource.add(position, new HashMap<>()); positions.add(position, docId); } - reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); + reconstructedSource.get(position).put(childFieldName, vectorValues.getVector()); offsetPositionsIndex = position + 1; } sourceAsMap.put(parentFieldName, reconstructedSource); From 1fe730284898f61bc299bbd426af30c30e8bbe64 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 10:31:04 -0800 Subject: [PATCH 15/18] Conditionally clone vectors Signed-off-by: John Mazanec --- 
.../derivedsource/NestedPerFieldDerivedVectorInjector.java | 2 +- .../codec/derivedsource/RootPerFieldDerivedVectorInjector.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index 73dc97483d..016b4153aa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx reconstructedSource.add(position, new HashMap<>()); positions.add(position, docId); } - reconstructedSource.get(position).put(childFieldName, vectorValues.getVector()); + reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); offsetPositionsIndex = position + 1; } sourceAsMap.put(parentFieldName, reconstructedSource); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java index b46744a51b..57b02f8577 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -40,7 +40,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade public void inject(int docId, Map sourceAsMap) throws IOException { KNNVectorValues vectorValues = vectorValuesSupplier.get(); if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) { - sourceAsMap.put(fieldInfo.name, vectorValues.getVector()); + sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector()); } } } From 
b3aba03fb11402307079be1764d086fd00753dbc Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 11:48:11 -0800 Subject: [PATCH 16/18] Make setting completely unmodifiable Signed-off-by: John Mazanec --- src/main/java/org/opensearch/knn/index/KNNSettings.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 8d58b8c887..34ab65b61a 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -274,7 +274,8 @@ public class KNNSettings { KNN_DERIVED_SOURCE_ENABLED, false, IndexScope, - Setting.Property.Final + Final, + UnmodifiableOnRestore ); /** From fd32a12576d44fdfae4103b31ca07995486be30d Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 13:45:06 -0800 Subject: [PATCH 17/18] Update based on feedback Signed-off-by: John Mazanec --- build.gradle | 2 + .../DerivedSourceVectorInjector.java | 1 + .../RootPerFieldDerivedVectorInjector.java | 2 +- .../opensearch/knn/integ/DerivedSourceIT.java | 123 ++++++++++++++---- .../org/opensearch/knn/KNNRestTestCase.java | 20 ++- 5 files changed, 119 insertions(+), 29 deletions(-) diff --git a/build.gradle b/build.gradle index 76c6d774df..17552cf97d 100644 --- a/build.gradle +++ b/build.gradle @@ -390,6 +390,7 @@ integTest { systemProperty("https", is_https) systemProperty("user", user) systemProperty("password", password) + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -451,6 +452,7 @@ task integTestRemote(type: RestIntegTestTask) { systemProperty 'cluster.number_of_nodes', "${_numNodes}" systemProperty 'tests.security.manager', 'false' + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) // Run tests with remote cluster only if rest case is defined if 
(System.getProperty("tests.rest.cluster") != null) { diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index c59bd43792..b2788f4053 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -71,6 +71,7 @@ public DerivedSourceVectorInjector( * @throws IOException if there is an issue reading from the formats */ public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException { + // TODO: Add link to core code // Deserialize the source into a modifiable map Tuple> mapTuple = XContentHelper.convertToMap( BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java index 57b02f8577..430fd24ae1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -16,7 +16,7 @@ /** * {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields). 
*/ -public class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { +class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { private final FieldInfo fieldInfo; private final CheckedSupplier, IOException> vectorValuesSupplier; diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index 2607a9695f..08c1919601 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -9,10 +9,16 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.KNNSettings; @@ -29,7 +35,6 @@ * Integration tests for derived source feature for vector fields. Currently, with derived source, there are * a few gaps in functionality. Ignoring tests for now as feature is experimental. 
*/ -// @Ignore public class DerivedSourceIT extends KNNRestTestCase { private final static String NESTED_NAME = "test_nested"; @@ -73,7 +78,7 @@ public class DerivedSourceIT extends KNNRestTestCase { * { * "settings": { * "index.knn" true, - * "index.knn.derived_source.enabled": true + * "index.knn.derived_source.enabled": false * }, * "mappings":{ * "properties": { @@ -473,7 +478,6 @@ public void testNestedSingleDocBasic() { * }, * } * } - * * } * } * } @@ -817,11 +821,12 @@ public void testObjectFieldTypes() { testDerivedSourceE2E(indexConfigContexts); } - // TODO Test configurations - // 1. Ensure compatibility with FLS - - // We need to write a single method that will run through all the different possible combinations and - // abstact when necessary. + /** + * Single method for running end to end tests for different index configurations for derived source. In general, + * the same flow of operations (e.g. merging, updates) is run for the given index configurations. + * + * @param indexConfigContexts {@link IndexConfigContext} + */ @SneakyThrows private void testDerivedSourceE2E(List indexConfigContexts) { // Make sure there are 6 @@ -833,9 +838,8 @@ private void testDerivedSourceE2E(List indexConfigContexts) // Merging testMerging(indexConfigContexts); - // Update - // TODO: Skipping nested for now - if (indexConfigContexts.get(0).isNested == false) { + // Update. Skipping update tests for nested docs for now. Will add in the future. 
+ if (indexConfigContexts.get(0).isNested() == false) { testUpdate(indexConfigContexts); } @@ -932,6 +936,7 @@ private void testUpdate(List indexConfigContexts) { originalIndexNameDerivedSourceEnabled ); + // Sets the doc to an empty doc setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); refreshAllIndices(); @@ -990,6 +995,79 @@ private void testUpdate(List indexConfigContexts) { ); } + @SneakyThrows + private void testSearch(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + + // Default - all fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, true, null, null); + + // Source disabled - no fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, false, null, null); + + // Exclude all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + null, + derivedSourceEnabledContext.vectorFieldNames + ); + + // Include all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + derivedSourceEnabledContext.vectorFieldNames, + null + ); + } + + @SneakyThrows + private void validateSearch(String indexName, int size, boolean isSourceEnabled, List includes, List excludes) { + // TODO: We need to figure out a way to enhance validation + QueryBuilder qb = new MatchAllQueryBuilder(); + Request request = new Request("POST", "/" + indexName + "/_search"); + + request.addParameter("size", Integer.toString(size)); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("query", qb); + if (isSourceEnabled == false) { + builder.field("_source", 
false); + } + if (includes != null) { + builder.startObject("_source"); + builder.startArray("includes"); + for (String include : includes) { + builder.value(include); + } + builder.endArray(); + builder.endObject(); + } + if (excludes != null) { + builder.startObject("_source"); + builder.startArray("excludes"); + for (String exclude : excludes) { + builder.value(exclude); + } + builder.endArray(); + builder.endObject(); + } + + builder.endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + List hits = parseSearchResponseHits(responseBody); + + assertNotEquals(0, hits.size()); + } + @SneakyThrows private void testDelete(List indexConfigContexts) { int docToDelete = 8; @@ -1021,12 +1099,6 @@ private void testDelete(List indexConfigContexts) { ); } - @SneakyThrows - private void testSearch(List indexConfigContexts) { - // TODO - - } - @SneakyThrows private void testReindex(List indexConfigContexts) { IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); @@ -1117,12 +1189,17 @@ private static class IndexConfigContext { @SneakyThrows private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { - int expectedSmaller = indexSizeInBytes(expectedSmallerIndex); - int expectedBigger = indexSizeInBytes(expectedBiggerIndex); - assertTrue( - "Expected smaller index " + expectedSmaller + " was bigger than the expected bigger index:" + expectedBigger, - expectedSmaller < expectedBigger - ); + if (isExhaustive()) { + logger.info("Checking index bigger assertion because running in exhaustive mode"); + int expectedSmaller = indexSizeInBytes(expectedSmallerIndex); + int expectedBigger = indexSizeInBytes(expectedBiggerIndex); + assertTrue( + "Expected smaller index " + 
expectedSmaller + " was bigger than the expected bigger index:" + expectedBigger, + expectedSmaller < expectedBigger + ); + } else { + logger.info("Skipping index bigger assertion because not running in exhaustive mode"); + } } private void assertDocsMatch(int docCount, String index1, String index2) { diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index e95deb558d..e206aa315b 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -173,6 +173,15 @@ public void cleanUpCache() throws Exception { clearCache(); } + /** + * Gives the ability for certain, more exhaustive checks, to be disabled by default + * + * @return If the test is running in exhaustive mode + */ + protected boolean isExhaustive() { + return Boolean.parseBoolean(System.getProperty("test.exhaustive", "false")); + } + /** * Create KNN Index with default settings */ @@ -276,6 +285,11 @@ protected Response performSearch(final String indexName, final String query, fin return response; } + protected List parseSearchResponseHits(String responseBody) throws IOException { + return (List) ((Map) createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map() + .get("hits")).get("hits"); + } + /** * Parse the response of KNN search into a List of KNNResults */ @@ -379,10 +393,6 @@ protected Double parseAggregationResponse(String responseBody, String aggregatio return Double.valueOf(String.valueOf(values.get("value"))); } - /** - * Parse the score from the KNN search response - */ - /** * Delete KNN index */ @@ -802,7 +812,7 @@ protected void updateKnnDoc(String index, String docId, String fieldName, Object } /** - * Update a KNN Doc with a new vector for the given fieldName + * Update a KNN Doc using the POST /\<index\>/_update/\<docId\>. Only the vector field will be updated. 
*/ protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException { Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true"); From 6462c4b476a122e2cf53e3a169fc686d7208070f Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 29 Jan 2025 14:20:19 -0800 Subject: [PATCH 18/18] Minor updates Signed-off-by: John Mazanec --- .../codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java | 2 ++ .../codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java | 2 ++ .../knn/index/codec/derivedsource/DerivedSourceReaders.java | 5 +++++ .../codec/derivedsource/DerivedSourceVectorInjector.java | 5 ++++- .../derivedsource/PerFieldDerivedVectorInjectorFactory.java | 2 +- 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index 55d8868dc1..e60b82b2e3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -15,6 +15,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; +import org.opensearch.common.Nullable; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; @@ -34,6 +35,7 @@ public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { private final StoredFieldsFormat delegate; private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; // IMPORTANT Do not rely on this for the reader, it will be null if SPI is used + @Nullable private final MapperService mapperService; @Override diff --git 
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index b01da60011..576bc9e987 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -72,6 +72,8 @@ public int merge(MergeState mergeState) throws IOException { public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOException { // Parse out the vectors from the source if (Objects.equals(fieldInfo.name, SourceFieldMapper.NAME) && !vectorFieldTypes.isEmpty()) { + // Reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 Tuple> mapTuple = XContentHelper.convertToMap( BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)), true, diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 3f2d418fb3..c7e472e601 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -12,6 +12,7 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.NormsProducer; import org.apache.lucene.util.IOUtils; +import org.opensearch.common.Nullable; import java.io.Closeable; import java.io.IOException; @@ -23,9 +24,13 @@ @RequiredArgsConstructor @Getter public class DerivedSourceReaders implements Closeable { + @Nullable private final KnnVectorsReader knnVectorsReader; + @Nullable private final DocValuesProducer docValuesProducer; + @Nullable private final FieldsProducer fieldsProducer; + @Nullable private final NormsProducer 
normsProducer; @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index b2788f4053..d3b1fe8469 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -71,7 +71,8 @@ public DerivedSourceVectorInjector( * @throws IOException if there is an issue reading from the formats */ public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException { - // TODO: Add link to core code + // Reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 // Deserialize the source into a modifiable map Tuple> mapTuple = XContentHelper.convertToMap( BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), @@ -88,6 +89,8 @@ public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException } // At this point, we can serialize the modified source map + // Setting to 1024 based on + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourcePhase.java#L106 BytesStreamOutput bStream = new BytesStreamOutput(1024); MediaType actualContentType = mapTuple.v1(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java index c0a1e0da00..d31d000837 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java @@ -11,7 +11,7 @@ /** * Factory for creating {@link PerFieldDerivedVectorInjector} instances. */ -public class PerFieldDerivedVectorInjectorFactory { +class PerFieldDerivedVectorInjectorFactory { /** * Create a {@link PerFieldDerivedVectorInjector} instance based on information in field info.