From 5bacaa71bdf9f485c295cf5b2f257a70772a1939 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Tue, 20 Aug 2024 19:01:21 -0700 Subject: [PATCH] Adds integration tests for NativeEngineKnnVectorsFormat code path Signed-off-by: Tejas Shah --- .gitignore | 5 + .../common/featureflags/KNNFeatureFlags.java | 31 +++++- .../org/opensearch/knn/index/KNNSettings.java | 31 ------ .../codec/BasePerFieldKnnVectorsFormat.java | 27 ++++- .../mapper/KNNVectorFieldMapperUtil.java | 3 +- .../knn/index/mapper/KNNVectorFieldType.java | 2 + .../knn/index/mapper/ModelFieldMapper.java | 7 +- .../knn/common/FieldInfoExtractorTests.java | 23 ++-- .../knn/index/FaissHNSWFlatE2EIT.java | 65 +++++++---- .../org/opensearch/knn/index/FaissIT.java | 23 +++- .../org/opensearch/knn/index/NmslibIT.java | 22 ++++ .../BasePerFieldKnnVectorsFormatTests.java | 104 ++++++++++++++++++ .../mapper/KNNVectorFieldMapperUtilTests.java | 10 +- 13 files changed, 275 insertions(+), 78 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormatTests.java diff --git a/.gitignore b/.gitignore index 7ff7640569..34f7ff1543 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,11 @@ jni/lib/ jni/jni_test* jni/googletest* jni/cmake/*.cmake-e +jni/.cmake +jni/.idea +jni/build.ninja +jni/.ninja_deps +jni/.ninja_log benchmarks/perf-tool/okpt/output benchmarks/perf-tool/okpt/dev diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java index 9b4a5ba7eb..3398681afa 100644 --- a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -28,6 +28,13 @@ public class KNNFeatureFlags { private static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.feature.query.rewrite.enabled"; private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = false; + /** + * TODO: This 
setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled"; + @VisibleForTesting public static final Setting KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting( KNN_LAUNCH_QUERY_REWRITE_ENABLED, @@ -36,8 +43,30 @@ public class KNNFeatureFlags { Dynamic ); + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final Setting KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting( + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, + false, + NodeScope, + Dynamic + ); + + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. 
+ */ + public static boolean getIsLuceneVectorFormatEnabled() { + return KNNSettings.state().getSettingValue(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED); + } + public static List<Setting<?>> getFeatureFlags() { - return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING).collect(Collectors.toUnmodifiableList()); + return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING, KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING) + .collect(Collectors.toUnmodifiableList()); } public static boolean isKnnQueryRewriteEnabled() { diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 73f43d3d1e..a70a17d858 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -83,12 +83,6 @@ public class KNNSettings { public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled"; - /** - * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the - * code is getting ready. 
Will remove this setting once all changes related to integration of KNNVectorsFormat is added - * for native engines. - */ - public static final Setting<Boolean> KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting( - KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, - false, - NodeScope - ); - /* * Quantization state cache settings */ @@ -449,10 +432,6 @@ private Setting<?> getSetting(String key) { return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; } - if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) { - return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING; - } - if (QUANTIZATION_STATE_CACHE_SIZE_LIMIT.equals(key)) { return QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING; } @@ -480,7 +459,6 @@ public List<Setting<?>> getSettings() { ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, KNN_FAISS_AVX2_DISABLED_SETTING, KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, - KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING ); @@ -528,15 +506,6 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - /** - * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the - * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added - * for native engines. 
- */ - public static boolean getIsLuceneVectorFormatEnabled() { - return KNNSettings.state().getSettingValue(KNNSettings.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED); - } - public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 8beced605d..fe149c5917 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -19,6 +19,7 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.indices.ModelMetadata; import java.util.Map; import java.util.Optional; @@ -28,6 +29,8 @@ import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.index.mapper.ModelFieldMapper.getKNNMethodContextFromModelMetadata; +import static org.opensearch.knn.indices.ModelUtil.getModelMetadata; /** * Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version @@ -78,14 +81,10 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ) ).fieldType(field); - KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() - .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); - - final KNNEngine engine = knnMethodContext.getKnnEngine(); + final KNNMethodContext knnMethodContext = extractKNNMethodContext(mappedFieldType); 
final Map params = knnMethodContext.getMethodComponentContext().getParameters(); - if (engine == KNNEngine.LUCENE) { + if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( params, @@ -133,4 +132,20 @@ public int getMaxDimensions(String fieldName) { private boolean isKnnVectorFieldType(final String field) { return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldType; } + + private KNNMethodContext extractKNNMethodContext(final KNNVectorFieldType mappedFieldType) { + final KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + final KNNMethodContext knnMethodContext; + if (knnMappingConfig.getModelId().isPresent()) { + ModelMetadata modelMetadata = getModelMetadata(knnMappingConfig.getModelId().get()); + assert modelMetadata != null : String.format("Model ID '%s' is not " + "created.", knnMappingConfig.getModelId().get()); + knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + } else if (knnMappingConfig.getKnnMethodContext().isPresent()) { + knnMethodContext = knnMappingConfig.getKnnMethodContext().get(); + } else { + throw new IllegalArgumentException("Could not extract KNN method context from mapped field [" + mappedFieldType + "]"); + } + + return knnMethodContext; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 57a4dd062b..386ead1222 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.settings.Settings; +import 
org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.SpaceType; @@ -143,7 +144,7 @@ static void validateIfKNNPluginEnabled() { * @return true if vector field should use KNNVectorsFormat */ static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) { - return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNSettings.getIsLuceneVectorFormatEnabled(); + return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNFeatureFlags.getIsLuceneVectorFormatEnabled(); } private static SpaceType getSpaceType(final Settings indexSettings) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 0fbc569f77..8a0f12a720 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.mapper; import lombok.Getter; +import lombok.ToString; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; @@ -29,6 +30,7 @@ /** * A KNNVector field type to represent the vector field in Opensearch */ +@ToString @Getter public class KNNVectorFieldType extends MappedFieldType { KNNMappingConfig knnMappingConfig; diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index b29466eefc..6d1f05078d 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -228,7 +228,12 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, modelMetadata.getDimension(), 
modelMetadata.getVectorDataType()); } - private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { + /** + * Extracts knnMethodContext from model metadata + * @param modelMetadata + * @return null if MethodComponentContext is empty else returns a new object of KNNMethodContext + */ + public static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (methodComponentContext == MethodComponentContext.EMPTY) { return null; diff --git a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java index 27aedd1d04..0b277ec83f 100644 --- a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java +++ b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java @@ -11,9 +11,12 @@ import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; +import java.util.Map; + import static org.mockito.Mockito.when; public class FieldInfoExtractorTests extends KNNTestCase { @@ -42,25 +45,23 @@ public void testExtractVectorDataType_whenDifferentConditions_thenSuccess() { } } - public void testExtractVectorDataType() { + public void testExtractKNNEngine() { FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); - when(fieldInfo.getAttribute("data_type")).thenReturn(VectorDataType.BINARY.getValue()); + when(fieldInfo.attributes()).thenReturn(Map.of("engine", KNNEngine.FAISS.getName())); - assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo)); - when(fieldInfo.getAttribute("data_type")).thenReturn(null); + assertEquals(KNNEngine.FAISS, FieldInfoExtractor.extractKNNEngine(fieldInfo)); + 
when(fieldInfo.getAttribute("engine")).thenReturn(null); - when(fieldInfo.getAttribute("model_id")).thenReturn(MODEL_ID); + when(fieldInfo.attributes()).thenReturn(Map.of("model_id", MODEL_ID)); try (MockedStatic modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)) { ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class); modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata); - when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); - assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo)); - when(modelMetadata.getVectorDataType()).thenReturn(null); - when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); + assertEquals(KNNEngine.FAISS, FieldInfoExtractor.extractKNNEngine(fieldInfo)); } - when(fieldInfo.getAttribute("model_id")).thenReturn(null); - assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + when(fieldInfo.attributes()).thenReturn(Map.of("blah", "blah")); + assertEquals(KNNEngine.NMSLIB, FieldInfoExtractor.extractKNNEngine(fieldInfo)); } } diff --git a/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java index 828fa656a0..f90f0417a3 100644 --- a/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java @@ -17,6 +17,7 @@ import lombok.AllArgsConstructor; import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentFactory; @@ -25,8 +26,8 @@ import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.query.KNNQueryBuilder; import 
org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.io.IOException; @@ -47,6 +48,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED; @AllArgsConstructor public class FaissHNSWFlatE2EIT extends KNNRestTestCase { @@ -55,6 +57,7 @@ public class FaissHNSWFlatE2EIT extends KNNRestTestCase { private int k; private Map methodParameters; private boolean deleteRandomDocs; + private Boolean knnUseLuceneVectorFormat; static TestUtils.TestData testData; @@ -70,14 +73,26 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s") + @Before + public void init() throws Exception { + super.setUp(); + updateClusterSettings(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, knnUseLuceneVectorFormat); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s, knnUseLuceneVectorFormat:%5$s") public static Collection parameters() { return Arrays.asList( $$( - $("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false), - $("Valid k, efsearch absent", 10, null, false), - $("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true), - $("Has delete docs", 10, null, true) + $("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false, false), + $("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true, false), + $( + "Valid k, valid efSearch efSearch value, knnVectors format code 
path", + 10, + Map.of(METHOD_PARAMETER_EF_SEARCH, 300), + false, + true + ), + $("Has delete docs, ef_search, knnVectors format code path", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true, true) ) ); } @@ -152,6 +167,29 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { } // Test search queries + // Without method parameters + search(fieldName, indexName, spaceType, null); + // With method parameters + search(fieldName, indexName, spaceType, methodParameters); + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + + @SneakyThrows + private void search(final String fieldName, final String indexName, final SpaceType spaceType, final Map methodParameters) { + for (int i = 0; i < testData.queries.length; i++) { final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder() .fieldName(fieldName) @@ -159,6 +197,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { .k(k) .methodParameters(methodParameters) .build(); + Response response = searchKNNIndex(indexName, queryBuilder, k); String responseBody = EntityUtils.toString(response.getEntity()); List knnResults = parseSearchResponse(responseBody, fieldName); @@ -174,19 +213,5 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { ); } } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 2df1d8a608..2b56106a3e 100644 --- 
a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -11,12 +11,15 @@ package org.opensearch.knn.index; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; @@ -38,6 +41,7 @@ import java.net.URL; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Locale; import java.util.Map; @@ -45,6 +49,8 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; @@ -76,7 +82,9 @@ import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED; +@AllArgsConstructor public class FaissIT extends KNNRestTestCase { private static final String DOC_ID_1 = "doc1"; private static final String DOC_ID_2 = "doc2"; @@ -95,6 +103,19 @@ public class FaissIT extends KNNRestTestCase { static TestUtils.TestData testData; + private Boolean knnUseLuceneVectorFormat; + + 
@ParametersFactory + public static Collection parameters() { + return Arrays.asList($$($(false), $(true))); + } + + @Before + public void init() throws Exception { + super.setUp(); + updateClusterSettings(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, knnUseLuceneVectorFormat); + } + @BeforeClass public static void setUpClass() throws IOException { if (FaissIT.class.getClassLoader() == null) { @@ -1090,7 +1111,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( /** * This test confirms that sharing index state for IVFPQ-l2 indices functions properly. The main functionality that * needs to be confirmed is that once an index gets deleted, it will not cause a failure for the non-deleted index. - * + *

* The workflow will be: * 1. Create a model * 2. Create two indices index from the model diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index 0d8dd9b12c..b51dd52628 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -11,10 +11,13 @@ package org.opensearch.knn.index; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.core.xcontent.XContentBuilder; @@ -31,16 +34,35 @@ import java.io.IOException; import java.net.URL; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.TreeMap; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED; +@AllArgsConstructor public class NmslibIT extends KNNRestTestCase { static TestUtils.TestData testData; + private Boolean knnUseLuceneVectorFormat; + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList($$($(false), $(true))); + } + + @Before + public void init() throws Exception { + super.setUp(); + updateClusterSettings(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, knnUseLuceneVectorFormat); + } + @BeforeClass public static void setUpClass() throws IOException { if (NmslibIT.class.getClassLoader() == null) { diff --git a/src/test/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormatTests.java 
b/src/test/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormatTests.java new file mode 100644 index 0000000000..1e8d6b8962 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormatTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; +import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.mapper.ModelFieldMapper; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; +import java.util.function.Function; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class BasePerFieldKnnVectorsFormatTests extends OpenSearchTestCase { + + private static final String FIELD = "field"; + private static final String MODEL_ID = "model_id"; + + @Mock + private MapperService mapperService; + @Mock + private Function vectorsFormatSupplier; + @Mock + private KNNMappingConfig knnMappingConfig; + @Mock + private KNNMethodContext knnMethodContext; + @Mock + private MethodComponentContext 
methodComponentContext; + + private BasePerFieldKnnVectorsFormat basePerFieldKnnVectorsFormat; + + @Before + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + basePerFieldKnnVectorsFormat = new BasePerFieldKnnVectorsFormat( + Optional.of(mapperService), + 10, + 10, + Lucene99HnswVectorsFormat::new, + vectorsFormatSupplier + ) { + }; + } + + public void testGetKNNVectorsFormatForField() { + MappedFieldType mappedFieldType = mock(MappedFieldType.class); + when(mapperService.fieldType(FIELD)).thenReturn(mappedFieldType); + + KnnVectorsFormat knnVectorsFormat = basePerFieldKnnVectorsFormat.getKnnVectorsFormatForField(FIELD); + assertEquals(Lucene99HnswVectorsFormat.class, knnVectorsFormat.getClass()); + + KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(knnMappingConfig); + when(knnMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext)); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(mapperService.fieldType(FIELD)).thenReturn(knnVectorFieldType); + + KnnVectorsFormat expected = new Lucene99HnswVectorsFormat(10, 10); + when(vectorsFormatSupplier.apply(new KNNVectorsFormatParams(null, 10, 10))).thenReturn(expected); + assertEquals(NativeEngines990KnnVectorsFormat.class, basePerFieldKnnVectorsFormat.getKnnVectorsFormatForField(FIELD).getClass()); + } + + public void testGetKNNVectorsFormatForFieldWithModel() { + ModelMetadata metadata = mock(ModelMetadata.class); + try ( + MockedStatic modelUtilMock = mockStatic(ModelUtil.class); + MockedStatic modelFieldMapperMock = mockStatic(ModelFieldMapper.class) + ) { + KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(knnMappingConfig); + when(knnMappingConfig.getModelId()).thenReturn(Optional.of(MODEL_ID)); + modelUtilMock.when(() -> 
ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(metadata); + modelFieldMapperMock.when(() -> ModelFieldMapper.getKNNMethodContextFromModelMetadata(metadata)).thenReturn(knnMethodContext); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(mapperService.fieldType(FIELD)).thenReturn(knnVectorFieldType); + + assertEquals( + NativeEngines990KnnVectorsFormat.class, + basePerFieldKnnVectorsFormat.getKnnVectorsFormatForField(FIELD).getClass() + ); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 5ebe3281ae..333db377e3 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -18,7 +18,7 @@ import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; @@ -76,15 +76,13 @@ public void testGetExpectedVectorLengthSuccess() { } public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { - final KNNSettings knnSettings = mock(KNNSettings.class); - final MockedStatic mockedStatic = Mockito.mockStatic(KNNSettings.class); - mockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + final MockedStatic mockedStatic = Mockito.mockStatic(KNNFeatureFlags.class); - mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(false); + mockedStatic.when(KNNFeatureFlags::getIsLuceneVectorFormatEnabled).thenReturn(false); Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_16_0)); 
Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); - mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(true); + mockedStatic.when(KNNFeatureFlags::getIsLuceneVectorFormatEnabled).thenReturn(true); Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0)); Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); // making sure to close the static mock to ensure that for tests running on this thread are not impacted by