diff --git a/CHANGELOG.md b/CHANGELOG.md index d2508a05e..1385fe1fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add a new build mode, `FAISS_OPT_LEVEL=avx512_spr`, which enables the use of advanced AVX-512 instructions introduced with Intel(R) Sapphire Rapids (#2404)[https://github.com/opensearch-project/k-NN/pull/2404] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] - Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345] - +- Add derived source feature for vector fields (#2449)[https://github.com/opensearch-project/k-NN/pull/2449] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/build.gradle b/build.gradle index 76c6d774d..17552cf97 100644 --- a/build.gradle +++ b/build.gradle @@ -390,6 +390,7 @@ integTest { systemProperty("https", is_https) systemProperty("user", user) systemProperty("password", password) + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -451,6 +452,7 @@ task integTestRemote(type: RestIntegTestTask) { systemProperty 'cluster.number_of_nodes', "${_numNodes}" systemProperty 'tests.security.manager', 'false' + systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) // Run tests with remote cluster only if rest case is defined if (System.getProperty("tests.rest.cluster") != null) { diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 170cfabbe..7939837aa 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -162,4 +162,8 @@ public class KNNConstants { public static final String MODE_PARAMETER = "mode"; public static final String COMPRESSION_LEVEL_PARAMETER = "compression_level"; + + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY = "knn-derived-source-enabled"; + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE = "true"; + public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_FALSE_VALUE = "false"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 8442af764..34ab65b61 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -93,6 +93,7 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; + public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; /** * Default setting values @@ -269,6 +270,14 @@ public class KNNSettings { Setting.Property.Dynamic ); + public static final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( + 
KNN_DERIVED_SOURCE_ENABLED, + false, + IndexScope, + Final, + UnmodifiableOnRestore + ); + /** * This setting identifies KNN index. */ @@ -518,6 +527,9 @@ private Setting getSetting(String key) { if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; } + if (KNN_DERIVED_SOURCE_ENABLED.equals(key)) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING; + } throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -543,7 +555,8 @@ public List> getSettings() { KNN_FAISS_AVX512_SPR_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, + KNN_DERIVED_SOURCE_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -581,6 +594,14 @@ public static boolean isFaissAVX2Disabled() { } } + /** + * check this index enabled/disabled derived source + * @param settings Settings + */ + public static boolean isKNNDerivedSourceEnabled(Settings settings) { + return KNN_DERIVED_SOURCE_ENABLED_SETTING.get(settings); + } + public static boolean isFaissAVX512Disabled() { return Booleans.parseBoolean( Objects.requireNonNullElse( diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java new file mode 100644 index 000000000..e60b82b2e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.opensearch.common.Nullable; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + +@AllArgsConstructor +public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { + + private final StoredFieldsFormat delegate; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + // IMPORTANT Do not rely on this for the reader, it will be null if SPI is used + @Nullable + private final MapperService mapperService; + + @Override + public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) + throws IOException { + List 
derivedVectorFields = null; + for (FieldInfo fieldInfo : fieldInfos) { + if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) { + // Lazily initialize the list of fields + if (derivedVectorFields == null) { + derivedVectorFields = new ArrayList<>(); + } + derivedVectorFields.add(fieldInfo); + } + } + // If no fields have it enabled, we can just short-circuit and return the delegate's fieldReader + if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { + return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); + } + return new DerivedSourceStoredFieldsReader( + delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), + derivedVectorFields, + derivedSourceReadersSupplier, + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + ); + } + + @Override + public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentInfo, IOContext ioContext) throws IOException { + StoredFieldsWriter delegateWriter = delegate.fieldsWriter(directory, segmentInfo, ioContext); + if (mapperService != null && KNNSettings.isKNNDerivedSourceEnabled(mapperService.getIndexSettings().getSettings())) { + List vectorFieldTypes = new ArrayList<>(); + for (MappedFieldType fieldType : mapperService.fieldTypes()) { + if (fieldType instanceof KNNVectorFieldType) { + vectorFieldTypes.add(fieldType.name()); + } + } + if (vectorFieldTypes.isEmpty() == false) { + return new DerivedSourceStoredFieldsWriter(delegateWriter, vectorFieldTypes); + } + } + return delegateWriter; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java new file mode 100644 index 000000000..6c1ade140 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.StoredFieldVisitor; +import org.apache.lucene.util.IOUtils; +import org.opensearch.index.fieldvisitor.FieldsVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; + +import java.io.IOException; +import java.util.List; + +public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { + private final StoredFieldsReader delegate; + private final List derivedVectorFields; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final SegmentReadState segmentReadState; + private final boolean shouldInject; + + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + /** + * + * @param delegate delegate StoredFieldsReader + * @param derivedVectorFields List of fields that are derived source fields + * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param segmentReadState SegmentReadState for the segment + * @throws IOException in case of I/O error + */ + public DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List 
derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState + ) throws IOException { + this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + } + + private DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + boolean shouldInject + ) throws IOException { + this.delegate = delegate; + this.derivedVectorFields = derivedVectorFields; + this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.segmentReadState = segmentReadState; + this.shouldInject = shouldInject; + this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); + } + + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { + return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + } + + @Override + public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { + // If the visitor has explicitly indicated it does not need the fields, we should not inject them + boolean isVisitorNeedFields = true; + if (storedFieldVisitor instanceof FieldsVisitor) { + isVisitorNeedFields = derivedSourceVectorInjector.shouldInject( + ((FieldsVisitor) storedFieldVisitor).includes(), + ((FieldsVisitor) storedFieldVisitor).excludes() + ); + } + if (shouldInject && isVisitorNeedFields) { + delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector)); + return; + } + delegate.document(docId, storedFieldVisitor); + } + + @Override + public StoredFieldsReader clone() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.clone(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + shouldInject + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public void close() throws IOException { + IOUtils.close(delegate, derivedSourceVectorInjector); + } + + /** + * For merging, we need to tell the derived source stored fields reader to skip injecting the source. 
Otherwise, + * on merge we will end up just writing the source to disk + * + * @return Merged instance that wont inject by default + */ + @Override + public StoredFieldsReader getMergeInstance() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.getMergeInstance(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + false + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java new file mode 100644 index 000000000..576bc9e98 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@RequiredArgsConstructor +public class DerivedSourceStoredFieldsWriter extends StoredFieldsWriter { + + private final StoredFieldsWriter delegate; + private final List vectorFieldTypes; + + @Override + public void startDocument() throws IOException { + delegate.startDocument(); + } + + @Override + public void writeField(FieldInfo fieldInfo, int i) throws IOException { + delegate.writeField(fieldInfo, i); + } + + @Override + public void writeField(FieldInfo fieldInfo, long l) throws IOException { + delegate.writeField(fieldInfo, l); + } + + @Override + public void writeField(FieldInfo fieldInfo, float v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo fieldInfo, double v) throws IOException { + delegate.writeField(fieldInfo, v); + } + + @Override + public void writeField(FieldInfo info, DataInput value, int length) throws IOException { + delegate.writeField(info, value, length); + } + + @Override + public int merge(MergeState mergeState) throws IOException { + return delegate.merge(mergeState); + } + + @Override + public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOException { + // Parse out the vectors from the source + if (Objects.equals(fieldInfo.name, SourceFieldMapper.NAME) && !vectorFieldTypes.isEmpty()) { + // Reference: + // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)), + true, + MediaTypeRegistry.JSON + ); + Map filteredSource = XContentMapValues.filter(null, 
vectorFieldTypes.toArray(new String[0])) + .apply(mapTuple.v2()); + BytesStreamOutput bStream = new BytesStreamOutput(); + MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(filteredSource); + builder.close(); + BytesReference bytesReference = bStream.bytes(); + delegate.writeField(fieldInfo, bytesReference.toBytesRef()); + return; + } + delegate.writeField(fieldInfo, bytesRef); + } + + @Override + public void writeField(FieldInfo fieldInfo, String s) throws IOException { + delegate.writeField(fieldInfo, s); + } + + @Override + public void finishDocument() throws IOException { + delegate.finishDocument(); + } + + @Override + public void finish(int i) throws IOException { + delegate.finish(i); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public long ramBytesUsed() { + return delegate.ramBytesUsed(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index a370197ec..b8a5e6a12 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -11,9 +11,12 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; /** * KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12 @@ -23,11 +26,13 @@ public class KNN9120Codec extends FilterCodec { private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final MapperService mapperService; + /** * No arg constructor that uses Lucene99 as the delegate */ public KNN9120Codec() { - this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null); } /** @@ -38,10 +43,11 @@ public KNN9120Codec() { * @param knnVectorsFormat per field format for KnnVector */ @Builder - protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) { super(VERSION.getCodecName(), delegate); knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; + this.mapperService = mapperService; } @Override @@ -58,4 +64,32 @@ public CompoundFormat compoundFormat() { public KnnVectorsFormat knnVectorsFormat() { return perFieldKnnVectorsFormat; } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { + if (segmentReadState.fieldInfos.hasVectorValues()) { + return knnVectorsFormat().fieldsReader(segmentReadState); + } + return null; + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasDocValues()) { + return docValuesFormat().fieldsProducer(segmentReadState); + } + return 
null; + + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasPostings()) { + return postingsFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState -> { + if (segmentReadState.fieldInfos.hasNorms()) { + return normsFormat().normsProducer(segmentReadState); + } + return null; + })); + return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 3df040785..6af6591f6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -126,6 +126,7 @@ public enum KNNCodecVersion { (userCodec, mapperService) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .mapperService(mapperService) .build(), KNN9120Codec::new ); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java new file mode 100644 index 000000000..123b718a4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaderSupplier.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +@FunctionalInterface +public interface DerivedSourceReaderSupplier { + R apply(SegmentReadState segmentReadState) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java new file mode 100644 index 000000000..c7e472e60 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.util.IOUtils; +import org.opensearch.common.Nullable; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Class holds the readers necessary to implement derived source. Important to note that if a segment does not have + * any of these fields, the values will be null. Caller needs to check if these are null before using. 
+ */ +@RequiredArgsConstructor +@Getter +public class DerivedSourceReaders implements Closeable { + @Nullable + private final KnnVectorsReader knnVectorsReader; + @Nullable + private final DocValuesProducer docValuesProducer; + @Nullable + private final FieldsProducer fieldsProducer; + @Nullable + private final NormsProducer normsProducer; + + @Override + public void close() throws IOException { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java new file mode 100644 index 000000000..2dafa3af9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.index.SegmentReadState; + +import java.io.IOException; + +/** + * Class encapsulates the suppliers to give the {@link DerivedSourceReaders} from particular formats needed to implement + * derived source. More specifically, given a {@link org.apache.lucene.index.SegmentReadState}, this class will provide + * the correct format reader for that segment. + */ +@RequiredArgsConstructor +public class DerivedSourceReadersSupplier { + private final DerivedSourceReaderSupplier knnVectorsReaderSupplier; + private final DerivedSourceReaderSupplier docValuesProducerSupplier; + private final DerivedSourceReaderSupplier fieldsProducerSupplier; + private final DerivedSourceReaderSupplier normsProducer; + + /** + * Get the readers for the segment + * + * @param state SegmentReadState + * @return DerivedSourceReaders + * @throws IOException in case of I/O error + */ + public DerivedSourceReaders getReaders(SegmentReadState state) throws IOException { + return new DerivedSourceReaders( + knnVectorsReaderSupplier.apply(state), + docValuesProducerSupplier.apply(state), + fieldsProducerSupplier.apply(state), + normsProducer.apply(state) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java new file mode 100644 index 000000000..9610eff68 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitor.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.AllArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.index.mapper.SourceFieldMapper; + +import java.io.IOException; + +/** + * Custom {@link StoredFieldVisitor} that wraps an upstream delegate visitor in order to transparently inject derived + * source vector fields into the document. After the source is modified, it is forwarded to the delegate. 
+ */ +@AllArgsConstructor +public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor { + + private final StoredFieldVisitor delegate; + private final Integer documentId; + private final DerivedSourceVectorInjector derivedSourceVectorInjector; + + @Override + public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { + if (fieldInfo.name.equals(SourceFieldMapper.NAME)) { + delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value)); + return; + } + delegate.binaryField(fieldInfo, value); + } + + @Override + public Status needsField(FieldInfo fieldInfo) throws IOException { + return delegate.needsField(fieldInfo); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java new file mode 100644 index 000000000..d3b1fe846 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.util.IOUtils; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative + * format readers and information about the fields to inject vectors into the source. + */ +@Log4j2 +public class DerivedSourceVectorInjector implements Closeable { + + private final DerivedSourceReaders derivedSourceReaders; + private final List perFieldDerivedVectorInjectors; + private final Set fieldNames; + + /** + * Constructor for DerivedSourceVectorInjector. + * + * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param segmentReadState Segment read state + * @param fieldsToInjectVector List of fields to inject vectors into + */ + public DerivedSourceVectorInjector( + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + List fieldsToInjectVector + ) throws IOException { + this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + this.perFieldDerivedVectorInjectors = new ArrayList<>(); + this.fieldNames = new HashSet<>(); + for (FieldInfo fieldInfo : fieldsToInjectVector) { + this.perFieldDerivedVectorInjectors.add( + PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState) + ); + this.fieldNames.add(fieldInfo.name); + } + } + + /** + * Given a docId and the source of that doc as bytes, add all the necessary vector fields into the source. 
+     *
+     * @param docId doc id of the document
+     * @param sourceAsBytes source of document as bytes
+     * @return byte array of the source with the vector fields added
+     * @throws IOException if there is an issue reading from the formats
+     */
+    public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException {
+        // Reference:
+        // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322
+        // Deserialize the source into a modifiable map
+        Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
+            BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)),
+            true,
+            MediaTypeRegistry.getDefaultMediaType()
+        );
+        // Have to create a copy of the map here to ensure that it is mutable
+        Map<String, Object> sourceAsMap = new HashMap<>(mapTuple.v2());
+
+        // For each vector field, add it back into the source. The per-field injectors are responsible for skipping if
+        // the field is not present.
+        for (PerFieldDerivedVectorInjector vectorInjector : perFieldDerivedVectorInjectors) {
+            vectorInjector.inject(docId, sourceAsMap);
+        }
+
+        // At this point, we can serialize the modified source map
+        // Setting to 1024 based on
+        // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourcePhase.java#L106
+        BytesStreamOutput bStream = new BytesStreamOutput(1024);
+        MediaType actualContentType = mapTuple.v1();
+        XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap);
+        builder.close();
+        return BytesReference.toBytes(BytesReference.bytes(builder));
+    }
+
+    /**
+     * Whether or not to inject vectors based on what fields are explicitly required
+     *
+     * @param includes List of fields that are required to be injected
+     * @param excludes List of fields that are not required to be injected
+     * @return true if vectors should be injected, false otherwise
+     */
+    public boolean shouldInject(String[] includes, String[] excludes) {
+        // If any of the vector fields are explicitly required we should inject
+        if (includes != null && includes != Strings.EMPTY_ARRAY) {
+            for (String includedField : includes) {
+                if (fieldNames.contains(includedField)) {
+                    return true;
+                }
+            }
+        }
+
+        // If all of the vector fields are explicitly excluded we should not inject
+        if (excludes != null && excludes != Strings.EMPTY_ARRAY) {
+            int excludedVectorFieldCount = 0;
+            for (String excludedField : excludes) {
+                if (fieldNames.contains(excludedField)) {
+                    excludedVectorFieldCount++;
+                }
+            }
+            // Inject if we haven't excluded all of the fields
+            return excludedVectorFieldCount < fieldNames.size();
+        }
+        return true;
+    }
+
+    @Override
+    public void close() throws IOException {
+        IOUtils.close(derivedSourceReaders);
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java
new file mode 100644
index 000000000..016b4153a
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java
@@ -0,0 +1,276 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.derivedsource;
+
+import lombok.AllArgsConstructor;
+import lombok.extern.log4j.Log4j2;
+import org.apache.lucene.index.DocValuesType;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.BytesRef;
+import org.opensearch.index.mapper.FieldNamesFieldMapper;
+import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
+import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+@Log4j2
+@AllArgsConstructor
+public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
+
+    private final FieldInfo childFieldInfo;
+    private final DerivedSourceReaders derivedSourceReaders;
+    private final SegmentReadState segmentReadState;
+
+    @Override
+    public void inject(int parentDocId, Map<String, Object> sourceAsMap) throws IOException {
+        // If the parent has the field, then it is just an object field.
+        int lowestDocIdForFieldWithParentAsOffset = getLowestDocIdForField(childFieldInfo.name, parentDocId);
+        if (lowestDocIdForFieldWithParentAsOffset == parentDocId) {
+            injectObject(parentDocId, sourceAsMap);
+            return;
+        }
+
+        // Set up the iterator. Return if there is no parent
+        String childFieldName = ParentChildHelper.getChildField(childFieldInfo.name);
+        String parentFieldName = ParentChildHelper.getParentField(childFieldInfo.name);
+        if (parentFieldName == null) {
+            return;
+        }
+        NestedPerFieldParentToDocIdIterator nestedPerFieldParentToDocIdIterator = new NestedPerFieldParentToDocIdIterator(
+            childFieldInfo,
+            segmentReadState,
+            derivedSourceReaders,
+            parentDocId
+        );
+
+        if (nestedPerFieldParentToDocIdIterator.numChildren() == 0) {
+            return;
+        }
+
+        // Initializes the parent field so that there is a list to put each of the children
+        Object originalParentValue = sourceAsMap.get(parentFieldName);
+        List<Map<String, Object>> reconstructedSource;
+        if (originalParentValue instanceof Map) {
+            reconstructedSource = new ArrayList<>(List.of((Map<String, Object>) originalParentValue));
+        } else {
+            reconstructedSource = (List<Map<String, Object>>) originalParentValue;
+        }
+
+        // Contains the docIds of existing objects in the map in order. This is used to help figure out the best place
+        // to put the vectors back
+        List<Integer> positions = mapObjectsToPositionInNestedList(
+            reconstructedSource,
+            nestedPerFieldParentToDocIdIterator.firstChild(),
+            parentDocId
+        );
+
+        // Finally, inject children for the document into the source. This code is non-trivial because filtering out
+        // the vectors during write could mean that child docs disappear from the source. So, to properly put
+        // everything back, we need to figure out where the existing fields in the original source map to.
+        KNNVectorValues<?> vectorValues = KNNVectorValuesFactory.getVectorValues(
+            childFieldInfo,
+            derivedSourceReaders.getDocValuesProducer(),
+            derivedSourceReaders.getKnnVectorsReader()
+        );
+        int offsetPositionsIndex = 0;
+        while (nestedPerFieldParentToDocIdIterator.nextChild() != NO_MORE_DOCS) {
+            // If the child does not have a vector, vectorValues.advance will skip past the child to the next matching
+            // docId. So, we need to ensure that doing this does not pass the parent docId.
+ if (nestedPerFieldParentToDocIdIterator.childId() > vectorValues.docId()) { + vectorValues.advance(nestedPerFieldParentToDocIdIterator.childId()); + } + if (vectorValues.docId() != nestedPerFieldParentToDocIdIterator.childId()) { + continue; + } + + int docId = nestedPerFieldParentToDocIdIterator.childId(); + boolean isInsert = true; + int position = positions.size(); // by default we insert it at the end + for (int i = offsetPositionsIndex; i < positions.size(); i++) { + if (docId < positions.get(i)) { + position = i; + break; + } + if (docId == positions.get(i)) { + isInsert = false; + position = i; + break; + } + } + + if (isInsert) { + reconstructedSource.add(position, new HashMap<>()); + positions.add(position, docId); + } + reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); + offsetPositionsIndex = position + 1; + } + sourceAsMap.put(parentFieldName, reconstructedSource); + } + + private void injectObject(int docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues( + childFieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + if (vectorValues.docId() != docId && vectorValues.advance(docId) != docId) { + return; + } + String[] fields = ParentChildHelper.splitPath(childFieldInfo.name); + Map currentMap = sourceAsMap; + for (int i = 0; i < fields.length - 1; i++) { + String field = fields[i]; + currentMap = (Map) currentMap.computeIfAbsent(field, k -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectorValues.getVector()); + } + + /** + * Given a list of maps, map each map to a position in the nested list. This is used to help figure out where to put + * the vectors back in the source. + * + * @param originals list of maps + * @param firstChild first child docId + * @param parent parent docId + * @return list of positions in the nested list + * @throws IOException if there is an issue reading from the formats + */ + private List mapObjectsToPositionInNestedList(List> originals, int firstChild, int parent) + throws IOException { + List positions = new ArrayList<>(); + int offset = firstChild; + for (Map docWithFields : originals) { + int fieldMapping = mapToDocId(docWithFields, offset, parent); + assert fieldMapping != -1; + positions.add(fieldMapping); + offset = fieldMapping + 1; + } + return positions; + } + + /** + * Given a doc as a map and the offset it has to be, find the ordinal of the first field that is greater than the + * offset. + * + * @param doc doc to find the ordinal for + * @param offset offset to start searching from + * @return id of the first field that is greater than the offset + * @throws IOException if there is an issue reading from the formats + */ + private int mapToDocId(Map doc, int offset, int parent) throws IOException { + // For all the fields, we look for the first doc that matches any of the fields. + int position = NO_MORE_DOCS; + for (String key : doc.keySet()) { + position = getLowestDocIdForField(ParentChildHelper.constructSiblingField(childFieldInfo.name, key), offset); + if (position < parent) { + break; + } + } + + // Advancing past the parent means something went wrong + assert position < parent; + return position; + } + + /** + * Get the lowest docId for a field that is greater than the offset. 
+     *
+     * @param fieldToMatch field to find the lowest docId for
+     * @param offset offset to start searching from
+     * @return lowest docId for the field that is greater than the offset. Returns {@link DocIdSetIterator#NO_MORE_DOCS} if the doc cannot be found
+     * @throws IOException if there is an issue reading from the formats
+     */
+    private int getLowestDocIdForField(String fieldToMatch, int offset) throws IOException {
+        // This method implementation is inspired by the FieldExistsQuery in Lucene and the FieldNamesMapper in
+        // OpenSearch. We first mimic the logic in the FieldExistsQuery in order to identify the docId of the nested
+        // doc. If that fails, we fall back to the postings of the _field_names field.
+        // References:
+        // 1. https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java#L170-L218.
+        // 2.
+        // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/mapper/FieldMapper.java#L316-L324
+        FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(fieldToMatch);
+
+        if (fieldInfo == null) {
+            return NO_MORE_DOCS;
+        }
+
+        DocIdSetIterator iterator = null;
+        if (fieldInfo.hasNorms() && derivedSourceReaders.getNormsProducer() != null) { // the field indexes norms
+            iterator = derivedSourceReaders.getNormsProducer().getNorms(fieldInfo);
+        } else if (fieldInfo.getVectorDimension() != 0 && derivedSourceReaders.getKnnVectorsReader() != null) { // the field indexes vectors
+            switch (fieldInfo.getVectorEncoding()) {
+                case FLOAT32:
+                    iterator = derivedSourceReaders.getKnnVectorsReader().getFloatVectorValues(fieldInfo.name);
+                    break;
+                case BYTE:
+                    iterator = derivedSourceReaders.getKnnVectorsReader().getByteVectorValues(fieldInfo.name);
+                    break;
+            }
+        } else if (fieldInfo.getDocValuesType() != DocValuesType.NONE && derivedSourceReaders.getDocValuesProducer() != null) { // the field indexes doc values
+            switch (fieldInfo.getDocValuesType()) {
+                case NUMERIC:
+                    iterator = derivedSourceReaders.getDocValuesProducer().getNumeric(fieldInfo);
+                    break;
+                case BINARY:
+                    iterator = derivedSourceReaders.getDocValuesProducer().getBinary(fieldInfo);
+                    break;
+                case SORTED:
+                    iterator = derivedSourceReaders.getDocValuesProducer().getSorted(fieldInfo);
+                    break;
+                case SORTED_NUMERIC:
+                    iterator = derivedSourceReaders.getDocValuesProducer().getSortedNumeric(fieldInfo);
+                    break;
+                case SORTED_SET:
+                    iterator = derivedSourceReaders.getDocValuesProducer().getSortedSet(fieldInfo);
+                    break;
+                case NONE:
+                default:
+                    throw new AssertionError();
+            }
+        }
+        if (iterator != null) {
+            return iterator.advance(offset);
+        }
+
+        // Check the field names field type for matches
+        if (derivedSourceReaders.getFieldsProducer() == null) {
+            return NO_MORE_DOCS;
+        }
+        Terms terms = derivedSourceReaders.getFieldsProducer().terms(FieldNamesFieldMapper.NAME);
+        if (terms == null) {
+            return NO_MORE_DOCS;
+        }
+        TermsEnum fieldNameFieldsTerms = terms.iterator();
+        BytesRef fieldToMatchRef = new BytesRef(fieldInfo.name);
+        PostingsEnum postingsEnum = null;
+        while (fieldNameFieldsTerms.next() != null) {
+            BytesRef currentTerm = fieldNameFieldsTerms.term();
+            if (currentTerm.bytesEquals(fieldToMatchRef)) {
+                postingsEnum = fieldNameFieldsTerms.postings(null);
+                break;
+            }
+        }
+        if (postingsEnum == null) {
+            return NO_MORE_DOCS;
+        }
+        return postingsEnum.advance(offset);
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java
b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java new file mode 100644 index 000000000..d2bc1a32f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldParentToDocIdIterator.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Iterator over the children documents of a particular parent + */ +public class NestedPerFieldParentToDocIdIterator { + + private final FieldInfo childFieldInfo; + private final SegmentReadState segmentReadState; + private final DerivedSourceReaders derivedSourceReaders; + private final int parentDocId; + private final int previousParentDocId; + private final List children; + private int currentChild; + + /** + * + * @param childFieldInfo FieldInfo for the child field + * @param segmentReadState SegmentReadState for the segment + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @param parentDocId Parent docId of the parent + * @throws IOException if there is an error reading the parent docId + */ + public NestedPerFieldParentToDocIdIterator( + FieldInfo childFieldInfo, + SegmentReadState segmentReadState, + DerivedSourceReaders derivedSourceReaders, + int parentDocId + ) throws IOException { + this.childFieldInfo = childFieldInfo; + this.segmentReadState = segmentReadState; + this.derivedSourceReaders = derivedSourceReaders; + this.parentDocId = parentDocId; + this.previousParentDocId = previousParent(); + this.children = getChildren(); + this.currentChild = -1; + } + + /** + * For the given parent get its first child offset + * + * @return the first child offset. If there are no children, just return NO_MORE_DOCS + */ + public int firstChild() { + if (parentDocId - previousParentDocId == 1) { + return NO_MORE_DOCS; + } + return previousParentDocId + 1; + } + + /** + * Get the next child for this parent + * + * @return the next child docId. If this has not been set, return -1. If there are no more children, return + * NO_MORE_DOCS + */ + public int nextChild() { + currentChild++; + if (currentChild >= children.size()) { + return NO_MORE_DOCS; + } + return children.get(currentChild); + } + + /** + * Get the current child for this parent + * + * @return the current child docId. If this has not been set, return -1 + */ + public int childId() { + return children.get(currentChild); + } + + /** + * + * @return the number of children for this parent + */ + public int numChildren() { + return children.size(); + } + + /** + * For parentDocId of this class, find the one just before it to be used for matching children. + * + * @return the parent docId just before the parentDocId. 
-1 if none exists
+     * @throws IOException if there is an error reading the parent docId
+     */
+    private int previousParent() throws IOException {
+        // TODO: In the future this needs to be generalized to handle multiple levels of nesting
+        // For now, for non-nested docs, the primary_term field can be used to identify root level docs. For reference:
+        // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/search/fetch/subphase/SeqNoPrimaryTermPhase.java#L72
+        // https://github.com/opensearch-project/OpenSearch/blob/3032bef54d502836789ea438f464ae0b1ba978b2/server/src/main/java/org/opensearch/index/mapper/SeqNoFieldMapper.java#L206-L230
+        // We use it here to identify the parent just before the current parent, which bounds the range of children documents
+        FieldInfo seqTermsFieldInfo = segmentReadState.fieldInfos.fieldInfo("_primary_term");
+        NumericDocValues numericDocValues = derivedSourceReaders.getDocValuesProducer().getNumeric(seqTermsFieldInfo);
+        int previousParentDocId = -1;
+        while (numericDocValues.nextDoc() != NO_MORE_DOCS) {
+            if (numericDocValues.docID() >= parentDocId) {
+                break;
+            }
+            previousParentDocId = numericDocValues.docID();
+        }
+        return previousParentDocId;
+    }
+
+    /**
+     * Get all the children that match the parent path recorded in the _nested_path field
+     *
+     * @return list of children that match the parent path
+     * @throws IOException if there is an error reading the children
+     */
+    private List<Integer> getChildren() throws IOException {
+        if (this.parentDocId - this.previousParentDocId <= 1) {
+            return Collections.emptyList();
+        }
+
+        // First, we need to get the correct PostingsEnum from the _nested_path field, whose terms are the actual
+        // parent paths.
+        String childField = childFieldInfo.name;
+        String parentField = ParentChildHelper.getParentField(childField);
+
+        Terms terms = derivedSourceReaders.getFieldsProducer().terms("_nested_path");
+        if (terms == null) {
+            return Collections.emptyList();
+        }
+        TermsEnum nestedFieldsTerms = terms.iterator();
+        BytesRef childPathRef = new BytesRef(parentField);
+        PostingsEnum postingsEnum = null;
+        while (nestedFieldsTerms.next() != null) {
+            BytesRef currentTerm = nestedFieldsTerms.term();
+            if (currentTerm.bytesEquals(childPathRef)) {
+                postingsEnum = nestedFieldsTerms.postings(null);
+                break;
+            }
+        }
+
+        // Next, get all the children that match this parent path. If none were found, return an empty list
+        if (postingsEnum == null) {
+            return Collections.emptyList();
+        }
+        List<Integer> children = new ArrayList<>();
+        postingsEnum.advance(previousParentDocId + 1);
+        while (postingsEnum.docID() != NO_MORE_DOCS && postingsEnum.docID() < parentDocId) {
+            if (postingsEnum.freq() > 0) {
+                children.add(postingsEnum.docID());
+            }
+            postingsEnum.nextDoc();
+        }
+
+        return children;
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java
new file mode 100644
index 000000000..534cf93d7
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.derivedsource;
+
+/**
+ * Helper class for working with nested fields.
+ */
+public class ParentChildHelper {
+
+    /**
+     * Given a nested field path, return the path of the parent field. For instance if the field is "parent.to.child",
+     * this would return "parent.to".
+ * + * @param field nested field path + * @return parent field path without the child + */ + public static String getParentField(String field) { + int lastDot = field.lastIndexOf('.'); + if (lastDot == -1) { + return null; + } + return field.substring(0, lastDot); + } + + /** + * Given a nested field path, return the child field. For instance if the field is "parent.to.child", this would + * return "child". + * + * @param field nested field path + * @return child field path without the parent path + */ + public static String getChildField(String field) { + int lastDot = field.lastIndexOf('.'); + return field.substring(lastDot + 1); + } + + /** + * Construct a sibling field path. For instance, if the field is "parent.to.child" and the sibling is "sibling", this + * would return "parent.to.sibling". + * + * @param field nested field path + * @param sibling sibling field + * @return sibling field path + */ + public static String constructSiblingField(String field, String sibling) { + return getParentField(field) + "." + sibling; + } + + /** + * Split a nested field path into an array of strings. For instance, if the field is "parent.to.child", this would + * return ["parent", "to", "child"]. + * + * @param field nested field path + * @return array of strings representing the nested field path + */ + public static String[] splitPath(String field) { + return field.split("\\."); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java new file mode 100644 index 000000000..b0bc5930c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjector.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import java.io.IOException; +import java.util.Map; + +/** + * Interface for injecting derived vectors into a source map per field. + */ +public interface PerFieldDerivedVectorInjector { + + /** + * Injects the derived vector for this field into the sourceAsMap. Implementing classes must handle the case where + * a document does not have a value for their field. + * + * @param docId Document ID + * @param sourceAsMap Source as map + * @throws IOException if there is an issue reading from the formats + */ + void inject(int docId, Map sourceAsMap) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java new file mode 100644 index 000000000..d31d00083 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; + +/** + * Factory for creating {@link PerFieldDerivedVectorInjector} instances. + */ +class PerFieldDerivedVectorInjectorFactory { + + /** + * Create a {@link PerFieldDerivedVectorInjector} instance based on information in field info. 
+ * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + * @return PerFieldDerivedVectorInjector instance + */ + public static PerFieldDerivedVectorInjector create( + FieldInfo fieldInfo, + DerivedSourceReaders derivedSourceReaders, + SegmentReadState segmentReadState + ) { + // Nested case + if (ParentChildHelper.getParentField(fieldInfo.name) != null) { + return new NestedPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders, segmentReadState); + } + + // Non-nested case + return new RootPerFieldDerivedVectorInjector(fieldInfo, derivedSourceReaders); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java new file mode 100644 index 000000000..430fd24ae --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.CheckedSupplier; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; +import java.util.Map; + +/** + * {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields). + */ +class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + + private final FieldInfo fieldInfo; + private final CheckedSupplier, IOException> vectorValuesSupplier; + + /** + * Constructor for RootPerFieldDerivedVectorInjector. 
+ * + * @param fieldInfo FieldInfo for the field to create the injector for + * @param derivedSourceReaders {@link DerivedSourceReaders} instance + */ + public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReaders derivedSourceReaders) { + this.fieldInfo = fieldInfo; + this.vectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues( + fieldInfo, + derivedSourceReaders.getDocValuesProducer(), + derivedSourceReaders.getKnnVectorsReader() + ); + } + + @Override + public void inject(int docId, Map sourceAsMap) throws IOException { + KNNVectorValues vectorValues = vectorValuesSupplier.get(); + if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) { + sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 9f1ebcf01..68ea25a1f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -14,6 +14,9 @@ import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; + /** * Mapper used when you dont want to build an underlying KNN struct - you just want to * store vectors as doc values @@ -32,7 +35,8 @@ public static FlatVectorFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -49,7 +53,8 @@ public static FlatVectorFieldMapper createFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -62,7 +67,8 @@ private FlatVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -73,13 +79,17 @@ private FlatVectorFieldMapper( stored, hasDocValues, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. 
this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.setDocValuesType(DocValuesType.BINARY); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.fieldType.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 99c6ebe2a..be485847c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -93,6 +93,7 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { */ public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); @@ -200,13 +201,15 @@ public Builder( ModelDao modelDao, Version indexCreatedVersion, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalParameters + OriginalMappingParameters originalParameters, + boolean isDerivedSourceEnabled ) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; this.knnMethodConfigContext = knnMethodConfigContext; this.originalParameters = originalParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } @Override @@ -258,7 +261,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { modelDao, indexCreatedVersion, originalParameters, - knnMethodConfigContext + knnMethodConfigContext, + isDerivedSourceEnabled ); } @@ -280,7 +284,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -301,7 +306,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { metaValue, knnMethodConfigContext, createLuceneFieldMapperInput, - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -315,7 +321,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.getValue(), hasDocValues.getValue(), - originalParameters + originalParameters, + isDerivedSourceEnabled ); } @@ -363,7 +370,8 @@ public Mapper.Builder parse(String name, Map node, ParserCont modelDaoSupplier.get(), parserContext.indexVersionCreated(), null, - null + null, + KNNSettings.isKNNDerivedSourceEnabled(parserContext.getSettings()) ); builder.parse(name, parserContext, node); builder.setOriginalParameters(new OriginalMappingParameters(builder)); @@ -569,6 +577,7 @@ static boolean useKNNMethodContextFromLegacy(Builder builder, Mapper.TypeParser. // values of KNN engine Algorithms hyperparameters. 
protected Version indexCreatedVersion; protected Explicit ignoreMalformed; + protected final boolean isDerivedSourceEnabled; protected boolean stored; protected boolean hasDocValues; protected VectorDataType vectorDataType; @@ -589,7 +598,8 @@ public KNNVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; @@ -599,6 +609,7 @@ public KNNVectorFieldMapper( updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; this.originalMappingParameters = originalMappingParameters; + this.isDerivedSourceEnabled = isDerivedSourceEnabled; } public KNNVectorFieldMapper clone() { @@ -831,7 +842,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { modelDao, indexCreatedVersion, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ).init(this); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4ceb9b4b2..49cd02d5b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -27,6 +27,8 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; @@ -48,7 +50,8 @@ static LuceneFieldMapper createFieldMapper( Map metaValue, KNNMethodConfigContext knnMethodConfigContext, CreateLuceneFieldMapperInput createLuceneFieldMapperInput, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -82,14 +85,21 @@ public Version getIndexCreatedVersion() { } ); - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext, originalMappingParameters); + return new LuceneFieldMapper( + mappedFieldType, + createLuceneFieldMapperInput, + knnMethodConfigContext, + originalMappingParameters, + isDerivedSourceEnabled + ); } private LuceneFieldMapper( final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( input.getName(), @@ -100,7 +110,8 @@ private LuceneFieldMapper( input.isStored(), input.isHasDocValues(), knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); @@ 
-117,6 +128,12 @@ private LuceneFieldMapper( this.vectorFieldType = null; } + if (isDerivedSourceEnabled) { + this.fieldType = new FieldType(this.fieldType); + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + this.fieldType.freeze(); + } + KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 814bc4f63..a2635b195 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -51,7 +53,8 @@ public static MethodFieldMapper createFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); @@ -104,7 +107,8 @@ public Version getIndexCreatedVersion() { stored, hasDocValues, knnMethodConfigContext, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -117,7 +121,8 @@ private MethodFieldMapper( boolean stored, boolean hasDocValues, KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( @@ -129,7 +134,8 @@ private MethodFieldMapper( stored, hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); @@ -151,7 +157,9 @@ private MethodFieldMapper( this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); - + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } try { this.fieldType.putAttribute( PARAMETERS, diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index d472090fc..ae912aa41 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -28,6 +28,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY; +import static org.opensearch.knn.common.KNNConstants.DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE; import 
static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; @@ -60,7 +62,8 @@ public static ModelFieldMapper createFieldMapper( ModelDao modelDao, Version indexCreatedVersion, OriginalMappingParameters originalMappingParameters, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext knnMethodConfigContext, + boolean isDerivedSourceEnabled ) { final KNNMethodContext knnMethodContext = originalMappingParameters.getKnnMethodContext(); @@ -134,7 +137,8 @@ private void initFromModelMetadata() { hasDocValues, modelDao, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); } @@ -148,7 +152,8 @@ private ModelFieldMapper( boolean hasDocValues, ModelDao modelDao, Version indexCreatedVersion, - OriginalMappingParameters originalMappingParameters + OriginalMappingParameters originalMappingParameters, + boolean isDerivedSourceEnabled ) { super( simpleName, @@ -159,7 +164,8 @@ private ModelFieldMapper( stored, hasDocValues, indexCreatedVersion, - originalMappingParameters + originalMappingParameters, + isDerivedSourceEnabled ); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); @@ -174,6 +180,9 @@ private ModelFieldMapper( this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); + if (isDerivedSourceEnabled) { + this.fieldType.putAttribute(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY, DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE); + } this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e217..9ae7f3842 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -72,6 +74,37 @@ public static KNNVectorValues getVectorValues(final FieldInfo fieldInfo, return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); } + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader} + * + * @param fieldInfo {@link FieldInfo} + * @param docValuesProducer {@link DocValuesProducer} + * @param knnVectorsReader {@link KnnVectorsReader} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues( + final FieldInfo fieldInfo, + final DocValuesProducer docValuesProducer, + final KnnVectorsReader knnVectorsReader + ) throws IOException { + final DocIdSetIterator docIdSetIterator; + if (fieldInfo.hasVectorValues() && knnVectorsReader != null) { + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + docIdSetIterator = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + docIdSetIterator = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); + } else { + 
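// Only BYTE and FLOAT32 encodings are supported when reading vectors back from a segment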
throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); + } + } else if (docValuesProducer != null) { + docIdSetIterator = docValuesProducer.getBinary(fieldInfo); + } else { + throw new IllegalArgumentException("Field does not have vector values and DocValues"); + } + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + @SuppressWarnings("unchecked") private static KNNVectorValues getVectorValues( final VectorDataType vectorDataType, diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 49b15a0f4..52ad3ded3 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -123,7 +123,8 @@ public void testBuilder_getParameters() { modelDao, CURRENT, null, - new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()) + new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null, SpaceType.UNDEFINED.getValue()), + false ); assertEquals(10, builder.getParameters().size()); @@ -357,7 +358,7 @@ public void testTypeParser_withSpaceTypeAndMode_thenSuccess() throws IOException public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -1177,7 +1178,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1216,7 +1218,8 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT new Explicit<>(true, true), false, false, - originalMappingParameters + originalMappingParameters, + false ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -1288,7 +1291,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false ); modelFieldMapper.parseCreateField(parseContext); @@ -1330,7 +1334,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy modelDao, CURRENT, originalMappingParameters, - knnMethodConfigContext + knnMethodConfigContext, + false ); modelFieldMapper.parseCreateField(parseContext); @@ -1376,7 +1381,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { Collections.emptyMap(), knnMethodConfigContext, inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1435,7 +1441,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { Collections.emptyMap(), knnMethodConfigContext, 
inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1482,7 +1489,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1532,7 +1540,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .dimension(TEST_DIMENSION) .build(), inputBuilder.build(), - originalMappingParameters + originalMappingParameters, + false ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1647,7 +1656,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithSQ_thenException() throws IOEx public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null, false); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1696,7 +1705,7 @@ public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { // IllegalArgumentException should be thrown. Exception e = assertThrows(IllegalArgumentException.class, () -> { - new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null).build(builderContext); + new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null, false).build(builderContext); }); assertTrue(e.getMessage(), e.getMessage().contains("Vector field name must not include")); } diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java new file mode 100644 index 000000000..08c191960 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -0,0 +1,1252 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.primitives.Floats; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; + +/** + * Integration tests for derived source feature for vector fields. Currently, with derived source, there are + * a few gaps in functionality. Ignoring tests for now as feature is experimental. 
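+ * Each test builds a derived source enabled index and an identical baseline index with derived source disabled, and + * then compares the documents returned by each after indexing, force merging, updating, deleting, searching, and reindexing.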
+ */ +public class DerivedSourceIT extends KNNRestTestCase { + + private final static String NESTED_NAME = "test_nested"; + private final static String FIELD_NAME = "test_vector"; + private final int TEST_DIMENSION = 128; + private final int DOCS = 50; + + private static final Settings DERIVED_ENABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, true) + .build(); + private static final Settings DERIVED_DISABLED_SETTINGS = Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .put(KNNSettings.KNN_DERIVED_SOURCE_ENABLED, false) + .build(); + + /** + * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for + * random documents to ensure it works robustly. To ensure correctness, we repeat same operations against an + * index without derived source enabled (baseline). + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * } + * } + * } + * } + */ + @SneakyThrows + public void testFlatBaseCase() { + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + 
.vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing multiple flat fields. + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "test_vector1": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + */ + @SneakyThrows + public void testMultiFlatFields() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject("text") + .field(TYPE, "text") + .endObject() + .endObject() + .endObject(); + String multiFieldMapping = builder.toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndMultFields( + context.indexName, + context.vectorFieldNames.get(0), + context.vectorFieldNames.get(1), + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) 
+ .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME + "1", FIELD_NAME + "2")) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(multiFieldMapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing single nested doc per parent doc. + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ + public void testNestedSingleDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNested( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." 
+ "text", + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing single nested doc per parent doc. + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_nested" : { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128 + * }, + * "text": { + * "type": "text", + * }, + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testNestedMultiDocBasic() { + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." 
+ FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + context.indexName, + context.vectorFieldNames.get(0), + NESTED_NAME + "." + "text", + context.docCount, + context.dimension, + 0.1f, + 5 + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(NESTED_NAME + "." + FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(nestedMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Test object (non-nested field) + * Test + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + * } + * } + * Baseline + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * { + * "properties": { + * "vector_field_1" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_1": { + * "properties" : { + * "vector_field_2" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * "path_2": { + * "properties" : { + * "vector_field_3" : { + * "type" : "knn_vector", + * "dimension" : 2 + * }, + * } + * } + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testObjectFieldTypes() { + String PATH_1_NAME = "path_1"; + String PATH_2_NAME = "path_2"; + + String objectFieldTypeMapping = XContentFactory.jsonBuilder() + .startObject() // 1-open + .startObject(PROPERTIES_FIELD) // 2-open + .startObject(FIELD_NAME + "1") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + + .startObject(PATH_1_NAME) + .startObject(PROPERTIES_FIELD) + + .startObject(FIELD_NAME + "2") + .field(TYPE, TYPE_KNN_VECTOR) + 
.field(DIMENSION, TEST_DIMENSION) + .endObject() + .startObject(PATH_2_NAME) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME + "3") + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsMultiFieldsWithSkips( + context.indexName, + context.vectorFieldNames, + List.of("text", PATH_1_NAME + "." + "text", PATH_1_NAME + "." + PATH_2_NAME + "." + "text"), + context.docCount, + context.dimension, + 0.1f + ); + refreshAllIndices(); + }) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." + FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames( + List.of( + FIELD_NAME + "1", + PATH_1_NAME + "." + FIELD_NAME + "2", + PATH_1_NAME + "." + PATH_2_NAME + "." 
+ FIELD_NAME + "3" + ) + ) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(objectFieldTypeMapping) + .isNested(true) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Single method for running end-to-end tests for different index configurations for derived source. In general, + * the flow of operations is: prepare and ingest the original indices, force merge, update, delete, search, and + * reindex, comparing the derived source enabled index against the disabled baseline along the way. + * + * @param indexConfigContexts {@link IndexConfigContext} + */ + @SneakyThrows + private void testDerivedSourceE2E(List indexConfigContexts) { + // Make sure there are 6 index configurations + assertEquals(6, indexConfigContexts.size()); + + // Prepare the indices by creating them and ingesting data into them + prepareOriginalIndices(indexConfigContexts); + + // Merging + testMerging(indexConfigContexts); + + // Update. Skipping update tests for nested docs for now. Will add in the future. + if (indexConfigContexts.get(0).isNested() == false) { + testUpdate(indexConfigContexts); + } + + // Delete + testDelete(indexConfigContexts); + + // Search + testSearch(indexConfigContexts); + + // Reindex + testReindex(indexConfigContexts); + } + + @SneakyThrows + private void prepareOriginalIndices(List indexConfigContexts) { + assertEquals(6, indexConfigContexts.size()); + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + createKnnIndex(derivedSourceEnabledContext.indexName, derivedSourceEnabledContext.settings, derivedSourceEnabledContext.mapping); + createKnnIndex(derivedSourceDisabledContext.indexName, derivedSourceDisabledContext.settings, derivedSourceDisabledContext.mapping); + derivedSourceEnabledContext.indexIngestor.accept(derivedSourceEnabledContext); + derivedSourceDisabledContext.indexIngestor.accept(derivedSourceDisabledContext); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + derivedSourceDisabledContext.indexName, + derivedSourceEnabledContext.indexName + ); + } + + @SneakyThrows + private void testMerging(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 10); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 10); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testUpdate(List indexConfigContexts) { + // Random variables + int docWithVectorUpdate = DOCS - 4; + int docWithVectorRemoval = 1; + int docWithVectorUpdateFromAPI = 2; + int docWithUpdateByQuery = 7; + + IndexConfigContext
derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + float[] updateVector = randomFloatVector(derivedSourceDisabledContext.dimension); + + // Update via POST //_doc/ + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDoc( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdate), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Sets the doc to an empty doc + setDocToEmpty(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorRemoval)); + setDocToEmpty(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorRemoval)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Use update API + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocWithUpdateAPI( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithVectorUpdateFromAPI), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Update by query + for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceEnabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { + updateKnnDocByQuery( + originalIndexNameDerivedSourceDisabled, + String.valueOf(docWithUpdateByQuery), + fieldName, + Floats.asList(updateVector).toArray() + ); + } + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testSearch(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + + // Default - all fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, true, null, null); + + // Default - no fields should be there + validateSearch(originalIndexNameDerivedSourceEnabled, derivedSourceEnabledContext.docCount, false, null, null); + + // Exclude all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + null, + 
derivedSourceEnabledContext.vectorFieldNames + ); + + // Include all vectors + validateSearch( + originalIndexNameDerivedSourceEnabled, + derivedSourceEnabledContext.docCount, + true, + derivedSourceEnabledContext.vectorFieldNames, + null + ); + } + + @SneakyThrows + private void validateSearch(String indexName, int size, boolean isSourceEnabled, List includes, List excludes) { + // TODO: We need to figure out a way to enhance validation + QueryBuilder qb = new MatchAllQueryBuilder(); + Request request = new Request("POST", "/" + indexName + "/_search"); + + request.addParameter("size", Integer.toString(size)); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("query", qb); + if (isSourceEnabled == false) { + builder.field("_source", false); + } + if (includes != null) { + builder.startObject("_source"); + builder.startArray("includes"); + for (String include : includes) { + builder.value(include); + } + builder.endArray(); + builder.endObject(); + } + if (excludes != null) { + builder.startObject("_source"); + builder.startArray("excludes"); + for (String exclude : excludes) { + builder.value(exclude); + } + builder.endArray(); + builder.endObject(); + } + + builder.endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + List hits = parseSearchResponseHits(responseBody); + + assertNotEquals(0, hits.size()); + } + + @SneakyThrows + private void testDelete(List indexConfigContexts) { + int docToDelete = 8; + int docToDeleteByQuery = 11; + + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + + // Delete by API + deleteKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDelete)); + deleteKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDelete)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + + // Delete by query + deleteKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docToDeleteByQuery)); + deleteKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docToDeleteByQuery)); + refreshAllIndices(); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + originalIndexNameDerivedSourceEnabled + ); + } + + @SneakyThrows + private void testReindex(List indexConfigContexts) { + IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + IndexConfigContext reindexFromEnabledToEnabledContext = indexConfigContexts.get(2); + IndexConfigContext reindexFromEnabledToDisabledContext = indexConfigContexts.get(3); + IndexConfigContext reindexFromDisabledToEnabledContext = indexConfigContexts.get(4); + IndexConfigContext reindexFromDisabledToDisabledContext = indexConfigContexts.get(5); + + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String 
originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + String reindexFromEnabledToEnabledIndexName = reindexFromEnabledToEnabledContext.indexName; + String reindexFromEnabledToDisabledIndexName = reindexFromEnabledToDisabledContext.indexName; + String reindexFromDisabledToEnabledIndexName = reindexFromDisabledToEnabledContext.indexName; + String reindexFromDisabledToDisabledIndexName = reindexFromDisabledToDisabledContext.indexName; + + createKnnIndex( + reindexFromEnabledToEnabledIndexName, + reindexFromEnabledToEnabledContext.getSettings(), + reindexFromEnabledToEnabledContext.getMapping() + ); + createKnnIndex( + reindexFromEnabledToDisabledIndexName, + reindexFromEnabledToDisabledContext.getSettings(), + reindexFromEnabledToDisabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToEnabledIndexName, + reindexFromDisabledToEnabledContext.getSettings(), + reindexFromDisabledToEnabledContext.getMapping() + ); + createKnnIndex( + reindexFromDisabledToDisabledIndexName, + reindexFromDisabledToDisabledContext.getSettings(), + reindexFromDisabledToDisabledContext.getMapping() + ); + refreshAllIndices(); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceEnabled, reindexFromEnabledToDisabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + reindex(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToDisabledIndexName); + + // Need to forcemerge before comparison + refreshAllIndices(); + forceMergeKnnIndex(originalIndexNameDerivedSourceEnabled, 1); + forceMergeKnnIndex(originalIndexNameDerivedSourceDisabled, 1); + refreshAllIndices(); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToDisabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToDisabledIndexName + ); + } + + @Builder + @Data + private static class IndexConfigContext { + String indexName; + List vectorFieldNames; + int dimension; + Settings settings; + String mapping; + boolean isNested; + int docCount; + CheckedConsumer indexIngestor; + } + + @SneakyThrows + private void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { + if (isExhaustive()) { + logger.info("Checking index bigger assertion because running in exhaustive mode"); + int expectedSmaller = indexSizeInBytes(expectedSmallerIndex); + int expectedBigger = indexSizeInBytes(expectedBiggerIndex); + assertTrue( + "Expected smaller index " + expectedSmaller + " was bigger than the expected bigger index:" + expectedBigger, + 
expectedSmaller < expectedBigger + ); + } else { + logger.info("Skipping index bigger assertion because not running in exhaustive mode"); + } + } + + private void assertDocsMatch(int docCount, String index1, String index2) { + for (int i = 0; i < docCount; i++) { + assertDocMatches(i + 1, index1, index2); + } + } + + @SneakyThrows + private void assertDocMatches(int docId, String index1, String index2) { + Map response1 = getKnnDoc(index1, String.valueOf(docId)); + Map response2 = getKnnDoc(index2, String.valueOf(docId)); + assertEquals("Docs do not match: " + docId, response1, response2); + } + + @SneakyThrows + private String createVectorNonNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + + @SneakyThrows + private String createVectorNestedMappings(final int dimension) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(NESTED_NAME) + .field(TYPE, "nested") + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 381b368c0..e206aa315 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -27,6 +27,7 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.derivedsource.ParentChildHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -70,6 +71,7 @@ import java.util.Objects; import java.util.Optional; import java.util.PriorityQueue; +import java.util.Random; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -171,6 +173,15 @@ public void cleanUpCache() throws Exception { clearCache(); } + /** + * Gives the ability for certain, more exhaustive checks, to be disabled by default + * + * @return If the test is running in exhaustive mode + */ + protected boolean isExhaustive() { + return Boolean.parseBoolean(System.getProperty("test.exhaustive", "false")); + } + /** * Create KNN Index with default settings */ @@ -274,6 +285,11 @@ protected Response performSearch(final String indexName, final String query, fin return response; } + protected List parseSearchResponseHits(String responseBody) throws IOException { + return (List) ((Map) createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map() + .get("hits")).get("hits"); + } + /** * Parse the response of KNN search into a List of KNNResults */ @@ -377,10 +393,6 @@ protected Double parseAggregationResponse(String responseBody, String aggregatio return Double.valueOf(String.valueOf(values.get("value"))); } - /** - * Parse the score from the KNN search response - */ - /** * Delete KNN index */ @@ -696,6 +708,28 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste 
         assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
     }
+    protected void addDocWithNestedNumericField(String index, String docId, String nestedFieldPath, long val) throws IOException {
+        String[] fieldParts = nestedFieldPath.split("\\.");
+
+        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
+        for (int i = 0; i < fieldParts.length - 1; i++) {
+            builder.startObject(fieldParts[i]);
+        }
+        builder.field(fieldParts[fieldParts.length - 1], val);
+        for (int i = fieldParts.length - 2; i >= 0; i--) {
+            builder.endObject();
+        }
+        builder.endObject();
+
+        Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
+        request.setJsonEntity(builder.toString());
+        client().performRequest(request);
+
+        request = new Request("POST", "/" + index + "/_refresh");
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+    }
+
     /**
      * Add a single KNN Doc to an index with multiple fields
      */
@@ -762,15 +796,90 @@ protected void addDocWithBinaryField(String index, String docId, String fieldNam
      */
     protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException {
         Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
+        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
+        String parent = ParentChildHelper.getParentField(fieldName);
+        if (parent != null) {
+            // A field under a nested/object parent must be written as { parent: { child: vector } };
+            // using the full dotted path inside the parent object would resolve to the wrong path.
+            builder.startObject(parent).field(ParentChildHelper.getChildField(fieldName), vector).endObject();
+        } else {
+            builder.field(fieldName, vector);
+        }
+        builder.endObject();
-        XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject();
+        request.setJsonEntity(builder.toString());
+
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+    }
+
+    /**
+     * Update a KNN Doc using POST /<index>/_update/<doc_id>. Only the vector field will be updated.
+     */
+    protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException {
+        Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true");
+        XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("doc");
+        String parent = ParentChildHelper.getParentField(fieldName);
+        if (parent != null) {
+            builder.startObject(parent).field(ParentChildHelper.getChildField(fieldName), vector).endObject();
+        } else {
+            builder.field(fieldName, vector);
+        }
+        builder.endObject().endObject();
+        request.setJsonEntity(builder.toString());
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+    }
+
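+    /**
+     * Update the vector field of docs whose "id" source field equals docId, via _update_by_query.
+     * The request body built below has the shape (abridged):
+     *   { "query": { "term": { "id": docId } },
+     *     "script": { "source": "ctx._source.<fieldName> = params.newValue", "lang": "painless",
+     *                 "params": { "newValue": vector } } }
+     */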
+ fieldName + " = params.newValue") + .field("lang", "painless") + .startObject("params") + .field("newValue", vector) + .endObject() + .endObject() + .endObject(); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void deleteKnnDocByQuery(String index, String docId) throws IOException { + // Put KNN mapping + Request request = new Request("POST", "/" + index + "/_delete_by_query?refresh"); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("term") + .field("id", docId) + .endObject() + .endObject() + .endObject(); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Update a KNN Doc with a new vector for the given fieldName + */ + protected void setDocToEmpty(String index, String docId) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().endObject(); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Delete Knn Doc */ @@ -787,6 +896,7 @@ protected void deleteKnnDoc(String index, String docId) throws IOException { */ protected Map getKnnDoc(final String index, final String docId) throws Exception { final Request request = new Request("GET", "/" + index + "/_doc/" + docId); + request.addParameter("ignore", "404"); final Response response = client().performRequest(request); final Map responseMap = createParser( @@ -795,8 +905,8 @@ protected Map getKnnDoc(final String index, final String docId) ).map(); assertNotNull(responseMap); - assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); - assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); + // assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); + // assertNotNull(responseMap.get(DOCUMENT_FIELD_SOURCE)); final Map docMap = (Map) responseMap.get(DOCUMENT_FIELD_SOURCE); @@ -819,6 +929,22 @@ protected void updateClusterSettings(String settingKey, Object value) throws Exc assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void reindex(String source, Object destination) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("source") + .field("index", source) + .endObject() + .startObject("dest") + .field("index", destination) + .endObject() + .endObject(); + Request request = new Request("POST", "_reindex"); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Return default index settings for index creation */ @@ -896,6 +1022,21 @@ protected Response executeKnnStatRequest(List nodeIds, List stat return response; } + protected int indexSizeInBytes(String indexName) throws IOException { + Request request = new Request("GET", indexName + "/_stats" + "/store"); + Response response = client().performRequest(request); 
+ assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String responseBody = EntityUtils.toString(response.getEntity()); + + @SuppressWarnings("unchecked") + Integer sizeInBytes = (Integer) ((Map) ((Map) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("_all")).get("primaries")).get("store")).get("size_in_bytes"); + + return sizeInBytes; + } + @SneakyThrows protected void doKnnWarmup(List indices) { Response response = knnWarmup(indices); @@ -1243,15 +1384,205 @@ public Map xContentBuilderToMap(XContentBuilder xContentBuilder) } public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + // TODO: Do better on this one + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); for (int i = 0; i < numVectors; i++) { - float[] vector = new float[dimension]; - for (int j = 0; j < dimension; j++) { - vector[j] = randomFloat(); + float[] vector = vectors[i]; + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } + } + + public void bulkIngestRandomVectorsWithSkips(String indexName, String fieldName, int numVectors, int dimension, float skipProb) + throws IOException { + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector = vectors[i]; + if (random.nextFloat() > skipProb) { + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); } + } + } - addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + public void bulkIngestRandomVectorsWithSkipsAndMultFields( + String indexName, + String fieldName1, + String fieldName2, + String fieldName3, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + float[][] vectors1 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + float[][] vectors2 = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 8); + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + float[] vector1 = vectors1[i]; + float[] vector2 = vectors2[i]; + + boolean includeFieldOne = random.nextFloat() > skipProb; + boolean includeFieldTwo = random.nextFloat() > skipProb; + boolean includeFieldThree = random.nextFloat() > skipProb; + + if (includeFieldOne || includeFieldTwo || includeFieldThree) { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + if (includeFieldOne) { + xContentBuilder.field(fieldName1, vector1); + } + if (includeFieldTwo) { + xContentBuilder.field(fieldName2, vector2); + } + if (includeFieldThree) { + xContentBuilder.field(fieldName3, "test-test"); + } + xContentBuilder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), xContentBuilder.toString()); + } else { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } } + } + + @SneakyThrows + public void bulkIngestRandomVectorsMultiFieldsWithSkips( + String indexName, + List vectorFields, + List textFields, + int numVectors, + int dimension, + float skipProb + ) { + List vectors = new ArrayList<>(); + int seed = 1; + for (String ignored : vectorFields) { + vectors.add(TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 
seed++)); + } + + Random random = new Random(); + random.setSeed(2); + for (int i = 0; i < numVectors; i++) { + + List includeVectorFields = new ArrayList<>(); + for (String ignored : vectorFields) { + includeVectorFields.add(random.nextFloat() > skipProb); + } + List includeTextFields = new ArrayList<>(); + for (String ignored : textFields) { + includeTextFields.add(random.nextFloat() > skipProb); + } + + // If all are skipped, just add a random field + if (includeVectorFields.stream().allMatch((t) -> !t) && includeTextFields.stream().allMatch((t) -> !t)) { + addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); + } else { + Map source = new HashMap<>(); + for (int j = 0; j < includeVectorFields.size(); j++) { + if (includeVectorFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(vectorFields.get(j)); + Map currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], vectors.get(j)[i]); + } + } + for (int j = 0; j < includeTextFields.size(); j++) { + if (includeTextFields.get(j)) { + String[] fields = ParentChildHelper.splitPath(textFields.get(j)); + Map currentMap = source; + for (int k = 0; k < fields.length - 1; k++) { + String field = fields[k]; + Object value = currentMap.get(field); + currentMap = (Map) currentMap.computeIfAbsent(field, t -> new HashMap<>()); + } + currentMap.put(fields[fields.length - 1], "test-test"); + } + } + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + mapToBuilder(builder, source); + builder.endObject(); + addKnnDoc(indexName, String.valueOf(i + 1), builder.toString()); + } + } + } + + @SneakyThrows + void mapToBuilder(XContentBuilder xContentBuilder, Map source) { + for (Map.Entry entry : source.entrySet()) { + if (entry.getValue() instanceof Map) { + xContentBuilder.startObject(entry.getKey()); + mapToBuilder(xContentBuilder, (Map) entry.getValue()); + xContentBuilder.endObject(); + } else { + xContentBuilder.field(entry.getKey(), entry.getValue()); + } + } + } + + public void bulkIngestRandomVectorsWithSkipsAndNested( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numVectors, + int dimension, + float skipProb + ) throws IOException { + bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + indexName, + nestedFieldName, + nestedNumericPath, + numVectors, + dimension, + skipProb, + 1 + ); + } + + public void bulkIngestRandomVectorsWithSkipsAndNestedMultiDoc( + String indexName, + String nestedFieldName, + String nestedNumericPath, + int numDocs, + int dimension, + float skipProb, + int maxDoc + ) throws IOException { + Random random = new Random(); + random.setSeed(2); + float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numDocs * maxDoc, dimension, 1); + for (int i = 0; i < numDocs; i++) { + int nestedDocs = random.nextInt(maxDoc) + 1; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startArray(ParentChildHelper.getParentField(nestedFieldName)); + for (int j = 0; j < nestedDocs; j++) { + builder.startObject(); + if (random.nextFloat() > skipProb) { + builder.field(ParentChildHelper.getChildField(nestedFieldName), vectors[i + j]); + } else { + builder.field(ParentChildHelper.getChildField(nestedNumericPath), 1); + } + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + addKnnDoc(indexName, String.valueOf(i 
+ 1), builder.toString()); + } + } + public float[] randomFloatVector(int dimension) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomFloat(); + } + return vector; } /**