Skip to content

Commit

Permalink
Adds integration tests for NativeEngineKnnVectorsFormat code path
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Aug 27, 2024
1 parent 33eb45b commit 5bacaa7
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 78 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ jni/lib/
jni/jni_test*
jni/googletest*
jni/cmake/*.cmake-e
jni/.cmake
jni/.idea
jni/build.ninja
jni/.ninja_deps
jni/.ninja_log

benchmarks/perf-tool/okpt/output
benchmarks/perf-tool/okpt/dev
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ public class KNNFeatureFlags {
private static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.feature.query.rewrite.enabled";
private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = false;

/**
 * Key of the temporary feature-flag setting related to the integration of KNNVectorsFormat for native engines.
 *
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled";

@VisibleForTesting
public static final Setting<Boolean> KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting(
KNN_LAUNCH_QUERY_REWRITE_ENABLED,
Expand All @@ -36,8 +43,30 @@ public class KNNFeatureFlags {
Dynamic
);

/**
 * Boolean setting registered under the {@code KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED} key. It defaults to
 * {@code false} and is declared with the {@code NodeScope} and {@code Dynamic} properties.
 *
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final Setting<Boolean> KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting(
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED,
false,
NodeScope,
Dynamic
);

/**
 * Looks up the current value stored under the {@code KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED} key through the
 * {@code KNNSettings} state.
 *
 * TODO: This accessor exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 *
 * @return the value of the feature flag as reported by {@code KNNSettings}
 */
public static boolean getIsLuceneVectorFormatEnabled() {
    final boolean luceneVectorFormatEnabled = KNNSettings.state().getSettingValue(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED);
    return luceneVectorFormatEnabled;
}

/**
 * Returns the feature-flag settings of this class, collected into an unmodifiable list.
 *
 * @return unmodifiable list of feature-flag {@code Setting}s
 */
public static List<Setting<?>> getFeatureFlags() {
// NOTE(review): two return statements appear back to back in this view — presumably the pre-change and
// post-change versions of the same statement rendered by a diff; confirm against the actual file.
return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING).collect(Collectors.toUnmodifiableList());
return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING, KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING)
.collect(Collectors.toUnmodifiableList());
}

public static boolean isKnnQueryRewriteEnabled() {
Expand Down
31 changes: 0 additions & 31 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ public class KNNSettings {
public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit";
public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold";
public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled";
/**
 * Key of the temporary setting related to the integration of KNNVectorsFormat for native engines.
 *
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled";
public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit";
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";

Expand Down Expand Up @@ -269,17 +263,6 @@ public class KNNSettings {
NodeScope
);

/**
 * Boolean setting registered under the {@code KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED} key. It defaults to
 * {@code false} and is declared with the {@code NodeScope} property only.
 *
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final Setting<Boolean> KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting(
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED,
false,
NodeScope
);

/*
* Quantization state cache settings
*/
Expand Down Expand Up @@ -449,10 +432,6 @@ private Setting<?> getSetting(String key) {
return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING;
}

if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) {
return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING;
}

if (QUANTIZATION_STATE_CACHE_SIZE_LIMIT.equals(key)) {
return QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING;
}
Expand Down Expand Up @@ -480,7 +459,6 @@ public List<Setting<?>> getSettings() {
ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING,
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING,
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING
);
Expand Down Expand Up @@ -528,15 +506,6 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

/**
 * Looks up the current value stored under the {@code KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED} key through the
 * {@code KNNSettings} state.
 *
 * TODO: This accessor exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 *
 * @return the value of the setting as reported by {@code KNNSettings}
 */
public static boolean getIsLuceneVectorFormatEnabled() {
    final boolean luceneVectorFormatEnabled = KNNSettings.state().getSettingValue(KNNSettings.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED);
    return luceneVectorFormatEnabled;
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.indices.ModelMetadata;

import java.util.Map;
import java.util.Optional;
Expand All @@ -28,6 +29,8 @@
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.index.mapper.ModelFieldMapper.getKNNMethodContextFromModelMetadata;
import static org.opensearch.knn.indices.ModelUtil.getModelMetadata;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
Expand Down Expand Up @@ -78,14 +81,10 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);

KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();
KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext()
.orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));

final KNNEngine engine = knnMethodContext.getKnnEngine();
final KNNMethodContext knnMethodContext = extractKNNMethodContext(mappedFieldType);
final Map<String, Object> params = knnMethodContext.getMethodComponentContext().getParameters();

if (engine == KNNEngine.LUCENE) {
if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) {
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
Expand Down Expand Up @@ -133,4 +132,20 @@ public int getMaxDimensions(String fieldName) {
/**
 * Tells whether the given field is mapped as a k-NN vector field.
 *
 * @param field name of the field to look up through the mapper service
 * @return {@code true} only when a mapper service is present and the field type it reports for
 *         {@code field} is a {@code KNNVectorFieldType}
 */
private boolean isKnnVectorFieldType(final String field) {
    // Without a mapper service there is nothing to look the field up in.
    if (mapperService.isPresent() == false) {
        return false;
    }
    return mapperService.get().fieldType(field) instanceof KNNVectorFieldType;
}

/**
 * Resolves the {@code KNNMethodContext} for a k-NN vector field.
 * <p>
 * When the field's mapping config carries a model ID, the method context is derived from that model's
 * metadata; otherwise the method context stored directly in the mapping config is used.
 *
 * @param mappedFieldType field type whose mapping config is inspected
 * @return the resolved method context, never {@code null}
 * @throws IllegalArgumentException if the referenced model has no metadata, if no method context can be
 *                                  derived from the model metadata, or if the mapping config has neither a
 *                                  model ID nor a method context
 */
private KNNMethodContext extractKNNMethodContext(final KNNVectorFieldType mappedFieldType) {
    final KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();

    if (knnMappingConfig.getModelId().isPresent()) {
        final String modelId = knnMappingConfig.getModelId().get();
        final ModelMetadata modelMetadata = getModelMetadata(modelId);
        // Explicit check instead of `assert`: assertions are disabled by default at runtime, which would
        // let a missing model surface later as an opaque NullPointerException.
        if (modelMetadata == null) {
            throw new IllegalArgumentException("Model ID '" + modelId + "' is not created.");
        }

        // The helper returns null when the model's method component context is empty; fail here with a
        // descriptive message rather than letting the caller dereference null.
        final KNNMethodContext modelMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
        if (modelMethodContext == null) {
            throw new IllegalArgumentException("Could not extract KNN method context from model [" + modelId + "]");
        }
        return modelMethodContext;
    }

    return knnMappingConfig.getKnnMethodContext()
        .orElseThrow(
            () -> new IllegalArgumentException("Could not extract KNN method context from mapped field [" + mappedFieldType + "]")
        );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KnnCircuitBreakerException;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -143,7 +144,7 @@ static void validateIfKNNPluginEnabled() {
* @return true if vector field should use KNNVectorsFormat
*/
static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) {
// NOTE(review): two return statements appear back to back in this view — presumably the pre-change and
// post-change versions of the same statement rendered by a diff; confirm against the actual file.
// Both conditions must hold: the index was created on or after V_2_17_0 AND the Lucene vector format
// flag reads as enabled.
return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNSettings.getIsLuceneVectorFormatEnabled();
return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNFeatureFlags.getIsLuceneVectorFormatEnabled();
}

private static SpaceType getSpaceType(final Settings indexSettings) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import lombok.ToString;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
Expand All @@ -29,6 +30,7 @@
/**
* A KNNVector field type to represent the vector field in Opensearch
*/
@ToString
@Getter
public class KNNVectorFieldType extends MappedFieldType {
KNNMappingConfig knnMappingConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
/**
 * Extracts the KNNMethodContext from the given model metadata.
 *
 * @param modelMetadata model metadata whose MethodComponentContext is read
 * @return {@code null} if the model's MethodComponentContext is empty; otherwise a new KNNMethodContext
 */
public static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (methodComponentContext == MethodComponentContext.EMPTY) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.Map;

import static org.mockito.Mockito.when;

public class FieldInfoExtractorTests extends KNNTestCase {
Expand Down Expand Up @@ -42,25 +45,23 @@ public void testExtractVectorDataType_whenDifferentConditions_thenSuccess() {
}
}

public void testExtractVectorDataType() {
public void testExtractKNNEngine() {
FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
when(fieldInfo.getAttribute("data_type")).thenReturn(VectorDataType.BINARY.getValue());
when(fieldInfo.attributes()).thenReturn(Map.of("engine", KNNEngine.FAISS.getName()));

assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo));
when(fieldInfo.getAttribute("data_type")).thenReturn(null);
assertEquals(KNNEngine.FAISS, FieldInfoExtractor.extractKNNEngine(fieldInfo));
when(fieldInfo.getAttribute("engine")).thenReturn(null);

when(fieldInfo.getAttribute("model_id")).thenReturn(MODEL_ID);
when(fieldInfo.attributes()).thenReturn(Map.of("model_id", MODEL_ID));
try (MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)) {
ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class);
modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata);
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);

assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo));
when(modelMetadata.getVectorDataType()).thenReturn(null);
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT);
assertEquals(KNNEngine.FAISS, FieldInfoExtractor.extractKNNEngine(fieldInfo));
}

when(fieldInfo.getAttribute("model_id")).thenReturn(null);
assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo));
when(fieldInfo.attributes()).thenReturn(Map.of("blah", "blah"));
assertEquals(KNNEngine.NMSLIB, FieldInfoExtractor.extractKNNEngine(fieldInfo));
}
}
65 changes: 45 additions & 20 deletions src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.Before;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -25,8 +26,8 @@
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

import java.io.IOException;
Expand All @@ -47,6 +48,7 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED;

@AllArgsConstructor
public class FaissHNSWFlatE2EIT extends KNNRestTestCase {
Expand All @@ -55,6 +57,7 @@ public class FaissHNSWFlatE2EIT extends KNNRestTestCase {
private int k;
private Map<String, ?> methodParameters;
private boolean deleteRandomDocs;
private Boolean knnUseLuceneVectorFormat;

static TestUtils.TestData testData;

Expand All @@ -70,14 +73,26 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
}

@ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s")
/**
 * Per-test initialization: invokes the parent class's {@code setUp()} and then passes this parameterized
 * instance's {@code knnUseLuceneVectorFormat} value to {@code updateClusterSettings} under the
 * {@code KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED} key.
 *
 * @throws Exception if the parent set-up or the settings update fails
 */
@Before
public void init() throws Exception {
// NOTE(review): if the parent's setUp() is itself annotated @Before, it will run twice per test — confirm.
super.setUp();
updateClusterSettings(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, knnUseLuceneVectorFormat);
}

@ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s, knnUseLuceneVectorFormat:%5$s")
public static Collection<Object[]> parameters() {
return Arrays.asList(
$$(
$("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false),
$("Valid k, efsearch absent", 10, null, false),
$("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true),
$("Has delete docs", 10, null, true)
$("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false, false),
$("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true, false),
$(
"Valid k, valid efSearch efSearch value, knnVectors format code path",
10,
Map.of(METHOD_PARAMETER_EF_SEARCH, 300),
false,
true
),
$("Has delete docs, ef_search, knnVectors format code path", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true, true)
)
);
}
Expand Down Expand Up @@ -152,13 +167,37 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
}

// Test search queries
// Without method parameters
search(fieldName, indexName, spaceType, null);
// With method parameters
search(fieldName, indexName, spaceType, methodParameters);

// Delete index
deleteKNNIndex(indexName);

// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}
Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}

@SneakyThrows
private void search(final String fieldName, final String indexName, final SpaceType spaceType, final Map<String, ?> methodParameters) {

for (int i = 0; i < testData.queries.length; i++) {
final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
.vector(testData.queries[i])
.k(k)
.methodParameters(methodParameters)
.build();

Response response = searchKNNIndex(indexName, queryBuilder, k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
Expand All @@ -174,19 +213,5 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
);
}
}

// Delete index
deleteKNNIndex(indexName);

// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}
Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}
}
Loading

0 comments on commit 5bacaa7

Please sign in to comment.