From 49fb96e5e700c06eede5905f86d7705fa99e466f Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sun, 11 Aug 2024 11:47:56 -0700 Subject: [PATCH] Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation Signed-off-by: Navneet Verma --- CHANGELOG.md | 3 +- .../org/opensearch/knn/index/KNNSettings.java | 33 ++++- .../index/mapper/FlatVectorFieldMapper.java | 4 + .../index/mapper/KNNVectorFieldMapper.java | 45 +++--- .../mapper/KNNVectorFieldMapperUtil.java | 16 +++ .../knn/index/mapper/LuceneFieldMapper.java | 8 +- .../knn/index/mapper/MethodFieldMapper.java | 20 +++ .../knn/index/mapper/ModelFieldMapper.java | 20 ++- .../knn/index/codec/KNNCodecTestCase.java | 10 +- .../mapper/KNNVectorFieldMapperTests.java | 128 ++++++++++++++++-- .../mapper/KNNVectorFieldMapperUtilTests.java | 22 +++ 11 files changed, 268 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81c90802e..f1dc5b14d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Infrastructure ### Documentation ### Maintenance @@ -32,4 +33,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. 
[#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) * Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) -* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) \ No newline at end of file +* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 33c7ff410..4ced38b38 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -82,6 +82,12 @@ public class KNNSettings { public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled"; + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled"; /** * Default setting values @@ -255,6 +261,17 @@ public class KNNSettings { NodeScope ); + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. 
+ */ + public static final Setting KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting( + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, + false, + NodeScope + ); + /** * Dynamic settings */ @@ -379,6 +396,10 @@ private Setting getSetting(String key) { return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; } + if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) { + return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -397,7 +418,8 @@ public List> getSettings() { MODEL_CACHE_SIZE_LIMIT_SETTING, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, KNN_FAISS_AVX2_DISABLED_SETTING, - KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING + KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -443,6 +465,15 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. 
+ */ + public static boolean getIsLuceneVectorFormatEnabled() { + return KNNSettings.state().getSettingValue(KNNSettings.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index fffff30f4..146b5132f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.knn.index.VectorDataType; @@ -57,8 +58,11 @@ private FlatVectorFieldMapper( Version indexCreatedVersion ) { super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. 
+ this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.setDocValuesType(DocValuesType.BINARY); this.fieldType.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 40eaa12ae..5d4d3ca58 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -16,7 +16,8 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.IndexOptions; import org.opensearch.Version; import org.opensearch.common.Explicit; @@ -456,6 +457,7 @@ public Mapper.Builder parse(String name, Map node, ParserCont protected boolean hasDocValues; protected VectorDataType vectorDataType; protected ModelDao modelDao; + protected boolean useLuceneBasedVectorField; // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the // Builder for serialization. So, we need to store it here. 
This is mainly to ensure that the legacy field mapper @@ -497,16 +499,29 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, fieldType().getKnnMappingConfig().getDimension(), fieldType().getVectorDataType()); } + private Field createVectorField(float[] vectorValue) { + if (useLuceneBasedVectorField) { + return new KnnFloatVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + + private Field createVectorField(byte[] vectorValue) { + if (useLuceneBasedVectorField) { + return new KnnByteVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + /** * Function returns a list of fields to be indexed when the vector is float type. * * @param array array of floats - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(array)); if (this.stored) { fields.add(createStoredFieldForFloatVector(name(), array)); } @@ -517,12 +532,11 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy * Function returns a list of fields to be indexed when the vector is byte type. 
* * @param array array of bytes - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForByteVector(final byte[] array, final FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] array) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(array)); if (this.stored) { fields.add(createStoredFieldForByteVector(name(), array)); } @@ -561,24 +575,14 @@ protected void validatePreparse() { protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { validatePreparse(); - if (VectorDataType.BINARY == vectorDataType) { - Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); - - if (bytesArrayOptional.isEmpty()) { - return; - } - final byte[] array = bytesArrayOptional.get(); - getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); - } else if (VectorDataType.BYTE == vectorDataType) { + if (VectorDataType.BINARY == vectorDataType || VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); - if (bytesArrayOptional.isEmpty()) { return; } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -587,7 +591,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForFloatVector(array, fieldType)); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new 
IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) @@ -714,7 +718,6 @@ public static class Defaults { static { FIELD_TYPE.setTokenized(false); FIELD_TYPE.setIndexOptions(IndexOptions.NONE); - FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type FIELD_TYPE.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 0caaf80ab..9cd6bb467 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -238,6 +238,22 @@ static void validateIfKNNPluginEnabled() { } } + /** + * Prerequisite: Index should a knn index which is validated via index settings index.knn setting. This function + * assumes that caller has already validated that index is a KNN index. + * We will use LuceneKNNVectorsFormat when these below condition satisfy: + *
<ol> + * <li>Index is created with Version of opensearch >= 2.17</li> + * <li>Cluster setting is enabled to use Lucene KNNVectors format. This condition is temporary and will be + * removed before release.</li> + * </ol>
+ * @param indexCreatedVersion {@link Version} + * @return true if vector field should use KNNVectorsFormat + */ + static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) { + return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNSettings.getIsLuceneVectorFormatEnabled(); + } + private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); if (spaceType == null) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 665c35f6e..7c3d942b6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -17,7 +17,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.Version; import org.opensearch.common.Explicit; @@ -112,9 +112,9 @@ private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final Create } @Override - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fieldsToBeAdded = new ArrayList<>(); - fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType)); + fieldsToBeAdded.add(new KnnFloatVectorField(name(), array, fieldType)); if (hasDocValues && vectorFieldType != null) { fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType)); @@ -127,7 +127,7 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy } @Override - protected List getFieldsForByteVector(final byte[] array, final 
FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] array) { final List fieldsToBeAdded = new ArrayList<>(); fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType)); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 7a69c941b..cc2c43386 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,9 +6,12 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -99,6 +102,7 @@ private MethodFieldMapper( indexVerision, originalKNNMethodContext ); + this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); @@ -118,6 +122,22 @@ private MethodFieldMapper( throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); } + if (useLuceneBasedVectorField) { + int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY + ? annConfig.getDimension() / 8 + : annConfig.getDimension(); + final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT + ? 
VectorEncoding.FLOAT32 + : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } + this.fieldType.freeze(); initValidatorsAndProcessors(knnMethodContext); knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index a21a01a5d..6c7e45e7e 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -6,9 +6,12 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; @@ -102,7 +105,7 @@ private ModelFieldMapper( this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); - this.fieldType.freeze(); + this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion); } @Override @@ -193,6 +196,21 @@ private void initPerDimensionProcessor() { protected void parseCreateField(ParseContext context) throws IOException { validatePreparse(); ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + if (useLuceneBasedVectorField) { + int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY + ? 
modelMetadata.getDimension() / 8 + : modelMetadata.getDimension(); + final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT + ? VectorEncoding.FLOAT32 + : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 00cc2b167..bf2c33bf9 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; @@ -89,8 +91,6 @@ * Test used for testing Codecs */ public class KNNCodecTestCase extends KNNTestCase { - - private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); private static final FieldType sampleFieldType; static { KNNMethodContext knnMethodContext = new KNNMethodContext( @@ -109,6 +109,7 @@ public class KNNCodecTestCase extends KNNTestCase { } sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + sampleFieldType.setDocValuesType(DocValuesType.BINARY); sampleFieldType.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE,
knnMethodContext.getKnnEngine().getName()); sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); @@ -259,6 +260,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio iwc.setCodec(codec); FieldType fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + fieldType.setDocValuesType(DocValuesType.BINARY); fieldType.putAttribute(KNNConstants.MODEL_ID, modelId); fieldType.freeze(); @@ -356,9 +358,9 @@ public void testKnnVectorIndex( /** * Add doc with field "test_vector_one" */ - final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); + final FieldType luceneFieldType = KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); float[] array = { 1.0f, 3.0f, 4.0f }; - KnnVectorField vectorField = new KnnVectorField(FIELD_NAME_ONE, array, luceneFieldType); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME_ONE, array, luceneFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); doc.add(vectorField); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index f06ff7935..e1d842112 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -7,11 +7,14 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; +import org.mockito.MockedStatic; import org.mockito.Mockito; 
import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.Explicit; @@ -27,14 +30,14 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -79,6 +82,7 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; +@Log4j2 public class KNNVectorFieldMapperTests extends KNNTestCase { private static final String TEST_FIELD_NAME = "test-field-name"; @@ -739,6 +743,112 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } + @SneakyThrows + public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() { + MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class); + for (VectorDataType dataType : VectorDataType.values()) { + log.info("Vector Data Type is : {}", dataType); + int dimension = dataType == VectorDataType.BINARY ? 
TEST_DIMENSION * 8 : TEST_DIMENSION; + final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + methodComponentContext.setIndexVersion(CURRENT); + SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT; + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + MethodFieldMapper methodFieldMapper = Mockito.spy( + MethodFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + dimension, + knnMethodContext, + knnMethodContext, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + CURRENT + ) + ); + + if (dataType == VectorDataType.BINARY) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType); + } else if (dataType == VectorDataType.BYTE) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType); + } else { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } + + methodFieldMapper.parseCreateField(parseContext, dimension, dataType); + + List fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field1 = fields.get(0); + if (dataType == VectorDataType.FLOAT) { + assertTrue(field1 instanceof KnnFloatVectorField); + assertEquals(field1.fieldType().vectorEncoding(), 
VectorEncoding.FLOAT32); + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE); + } + + assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION); + assertEquals( + field1.fieldType().vectorSimilarityFunction(), + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); + + document = new ParseContext.Document(); + contentPath = new ContentPath(); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + methodFieldMapper = Mockito.spy( + MethodFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + dimension, + knnMethodContext, + knnMethodContext, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + CURRENT + ) + ); + + if (dataType == VectorDataType.FLOAT) { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } else { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? 
TEST_DIMENSION * 8 : TEST_DIMENSION, dataType); + } + + methodFieldMapper.parseCreateField(parseContext, dimension, dataType); + fields = document.getFields(); + assertEquals(1, fields.size()); + field1 = fields.get(0); + assertTrue(field1 instanceof VectorField); + } + // making sure to close the static mock to ensure that for tests running on this thread are not impacted by + // this mocking + utilMockedStatic.close(); + } + @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField @@ -765,22 +875,22 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { doNothing().when(luceneFieldMapper).validatePreparse(); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); - // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField + // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnFloatVectorField List fields = document.getFields(); assertEquals(2, fields.size()); IndexableField field1 = fields.get(0); IndexableField field2 = fields.get(1); VectorField vectorField; - KnnVectorField knnVectorField; + KnnFloatVectorField knnVectorField; if (field1 instanceof VectorField) { assertTrue(field2 instanceof KnnVectorField); vectorField = (VectorField) field1; - knnVectorField = (KnnVectorField) field2; + knnVectorField = (KnnFloatVectorField) field2; } else { - assertTrue(field1 instanceof KnnVectorField); + assertTrue(field1 instanceof KnnFloatVectorField); assertTrue(field2 instanceof VectorField); - knnVectorField = (KnnVectorField) field1; + knnVectorField = (KnnFloatVectorField) field1; vectorField = (VectorField) field2; } @@ -821,8 +931,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { fields = document.getFields(); assertEquals(1, fields.size()); IndexableField field = 
fields.get(0); - assertTrue(field instanceof KnnVectorField); - knnVectorField = (KnnVectorField) field; + assertTrue(field instanceof KnnFloatVectorField); + knnVectorField = (KnnFloatVectorField) field; assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 8ace5557e..a80110181 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,8 +13,13 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.util.BytesRef; +import org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -105,6 +110,23 @@ public void testValidateVectorDataType_whenFloat_thenValid() { validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); } + public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { + final KNNSettings knnSettings = mock(KNNSettings.class); + final MockedStatic mockedStatic = Mockito.mockStatic(KNNSettings.class); + mockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + + mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(false); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_16_0)); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); + + 
mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(true); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0)); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); + // making sure to close the static mock to ensure that for tests running on this thread are not impacted by + // this mocking + mockedStatic.close(); + } + private void validateValidateVectorDataType( final KNNEngine knnEngine, final String methodName,