diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java index e8bf476253..174c4f0a5a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java @@ -135,7 +135,7 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R List efConstructionValues = ImmutableList.of(16, 32, 64, 128); List efSearchValues = ImmutableList.of(16, 32, 64, 128); - int dimension = 2; + int dimension = 128; // Create an index /** @@ -198,16 +198,35 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R createKnnIndex(testIndex, mapping); assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(testIndex))); - Float[] vector1 = { -65523.76f, 65504.2f }; - Float[] vector2 = { -270.85f, 65514.2f }; - Float[] vector3 = { -150.9f, 65504.0f }; - Float[] vector4 = { -20.89f, 100000000.0f }; + + Float[] vector1 = new Float[dimension]; + Float[] vector2 = new Float[dimension]; + Float[] vector3 = new Float[dimension]; + Float[] vector4 = new Float[dimension]; + float[] queryVector = new float[dimension]; + int halfDimension = dimension / 2; + + for (int i = 0; i < dimension; i++) { + if (i < halfDimension) { + vector1[i] = -65523.76f; + vector2[i] = -270.85f; + vector3[i] = -150.9f; + vector4[i] = -20.89f; + queryVector[i] = -10.5f; + } else { + vector1[i] = 65504.2f; + vector2[i] = 65514.2f; + vector3[i] = 65504.0f; + vector4[i] = 100000000.0f; + queryVector[i] = 25.48f; + } + } + addKnnDoc(testIndex, "1", TEST_FIELD, vector1); addKnnDoc(testIndex, "2", TEST_FIELD, vector2); addKnnDoc(testIndex, "3", TEST_FIELD, vector3); addKnnDoc(testIndex, "4", TEST_FIELD, vector4); - float[] queryVector = { -10.5f, 25.48f }; int k = 4; Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k); List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD); diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index 822f0e2ca1..cd34a1ecc0 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -24,11 +24,6 @@ import java.util.Set; import java.util.function.Function; -import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; - /** * Abstract class for KNN methods. This class provides the common functionality for all KNN methods. * It defines the common attributes and methods that all KNN methods should implement. @@ -116,49 +111,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( protected Function doGetTrainingConfigValidationSetup() { return (trainingConfigValidationInput) -> { - - KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); - KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext(); - Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - - // validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension - if (knnMethodContext != null && knnMethodConfigContext != null) { - if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M) - && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(ENCODER_PARAMETER_PQ_M) != 0) { - builder.valid(false); - return builder.build(); - } else { - builder.valid(true); - } - } - - // validate number of training points should be greater than minimum clustering criteria defined in faiss - if (knnMethodContext != null && trainingVectors != null) { - long minTrainingVectorCount = 1000; - - MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER); - - if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST) - && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { - - int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST)); - int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); - minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size)); - } - - if (trainingVectors < minTrainingVectorCount) { - builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); - return builder.build(); - } else { - builder.valid(true); - } - } return builder.build(); }; } diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java index 8127a041da..31357cc1fe 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java @@ -182,4 +182,5 @@ protected void validateCompressionConflicts(CompressionLevel originalCompression throw validationException; } } + } diff --git a/src/main/java/org/opensearch/knn/index/engine/Encoder.java b/src/main/java/org/opensearch/knn/index/engine/Encoder.java index f15d0afcf3..76b9d02903 100644 --- a/src/main/java/org/opensearch/knn/index/engine/Encoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/Encoder.java @@ -36,4 +36,14 @@ default String getName() { * return {@link CompressionLevel#NOT_CONFIGURED} */ CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext); + + /** + * Validates config of encoder + * + * @return Validation output of encoder parameters + */ + default TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput) { + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + return builder.build(); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java index 0cbe6cad5d..3c000f4040 100644 --- a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java +++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java @@ -20,6 +20,6 @@ @Builder @AllArgsConstructor public class TrainingConfigValidationOutput { - private boolean valid; - private long minTrainingVectorCount; + private Boolean valid; + private Long minTrainingVectorCount; } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 356292678f..107640ba90 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -129,6 +129,33 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m return (MethodComponentContext) object; } + protected String getEncoderName(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext); + if (methodComponentContext == null) { + return null; + } + + return methodComponentContext.getName(); + } + + protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER); + } + + protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) { + return knnMethodContext != null + && knnMethodContext.getMethodComponentContext().getParameters() != null + && knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER); + } + @Override protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) { // While FAISS doesn't directly support cosine similarity, we can leverage the mathematical diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java index a894d8ed63..0ff8d85660 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java @@ -13,6 +13,11 @@ import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; + +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; +import org.opensearch.knn.index.engine.KNNMethodContext; /** * Abstract class for Faiss PQ encoders. This class provides the common logic for product quantization based encoders @@ -89,4 +94,50 @@ public CompressionLevel calculateCompressionLevel( // compression return CompressionLevel.MAX_COMPRESSION_LEVEL; } + + @Override + public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { + KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); + KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext(); + Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); + + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + + // validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension + if (knnMethodContext != null && knnMethodConfigContext != null) { + if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M) + && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(ENCODER_PARAMETER_PQ_M) != 0) { + builder.valid(false); + return builder.build(); + } else { + builder.valid(true); + } + } + + // validate number of training points should be greater than minimum clustering criteria defined in faiss + if (knnMethodContext != null && trainingVectors != null) { + long minTrainingVectorCount = 1000; + + MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER); + + if (encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { + + int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); + minTrainingVectorCount = (long) Math.pow(2, code_size); + } + + if (trainingVectors < minTrainingVectorCount) { + builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); + return builder.build(); + } else { + builder.valid(true); + } + } + + return builder.build(); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index 3386f871c7..b1b0f00631 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -16,12 +16,16 @@ import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION; @@ -124,4 +128,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter() SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) ); } + + @Override + protected Function doGetTrainingConfigValidationSetup() { + return (trainingConfigValidationInput) -> { + + KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + + if (isEncoderSpecified(knnMethodContext) == false) { + return builder.build(); + } + Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext)); + if (encoder == null) { + return builder.build(); + } + + return encoder.validateEncoderConfig(trainingConfigValidationInput); + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 5820293921..731829e7a4 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -15,12 +15,16 @@ import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; @@ -150,4 +154,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter() SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) ); } + + @Override + protected Function doGetTrainingConfigValidationSetup() { + return (trainingConfigValidationInput) -> { + + KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + + if (isEncoderSpecified(knnMethodContext) == false) { + return builder.build(); + } + Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext)); + if (encoder == null) { + return builder.build(); + } + + return encoder.validateEncoderConfig(trainingConfigValidationInput); + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java index c976a0959b..f7d8642fe7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java @@ -15,6 +15,8 @@ import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationInput; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; @@ -73,6 +75,9 @@ public ResolvedMethodContext resolveMethod( encoderMap ); + // Validate encoder parameters + validateEncoderConfig(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap); + // Validate that resolved compression doesnt have any conflicts validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel); knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel); @@ -148,6 +153,32 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) { } } + protected void validateEncoderConfig( + KNNMethodContext resolvedKnnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + Map encoderMap + ) { + if (isEncoderSpecified(resolvedKnnMethodContext) == false) { + return; + } + Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext)); + if (encoder == null) { + return; + } + + TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder(); + + TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig( + inputBuilder.knnMethodContext(resolvedKnnMethodContext).knnMethodConfigContext(knnMethodConfigContext).build() + ); + + if (validationOutput.getValid() != null && !validationOutput.getValid()) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions"); + throw validationException; + } + } + private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) { if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) { return knnMethodConfigContext.getCompressionLevel(); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 5de0405c45..7c28b9ae47 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -154,7 +154,7 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques TrainingConfigValidationOutput validation = validateTrainingConfig.apply( inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build() ); - if (!validation.isValid()) { + if (validation.getValid() != null && !validation.getValid()) { ValidationException exception = new ValidationException(); exception.addValidationError( String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount()) diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index bd2c883477..9906ab490b 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -30,14 +30,10 @@ import org.opensearch.knn.index.engine.EngineResolver; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.TrainingConfigValidationInput; -import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import java.io.IOException; -import java.util.function.Function; /** * Request to train and serialize a model @@ -287,21 +283,6 @@ public ActionRequestValidationException validate() { exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters"); } - KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); - Function validateTrainingConfig = knnLibraryIndexingContext - .getTrainingConfigValidationSetup(); - TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder(); - TrainingConfigValidationOutput validation = validateTrainingConfig.apply( - inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build() - ); - - // Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension - if (!validation.isValid()) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions"); - } - // Validate training index exists IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex); if (indexMetadata == null) { diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 20249237d4..0b14696fe0 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -796,7 +796,7 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { List efConstructionValues = ImmutableList.of(16, 32, 64, 128); List efSearchValues = ImmutableList.of(16, 32, 64, 128); - int dimension = 2; + int dimension = 128; // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() @@ -830,7 +830,23 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { createKnnIndex(indexName, mapping); assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - Float[] vector = { -10.76f, 65504.2f }; + + Float[] vector = new Float[dimension]; + Float[] vector1 = new Float[dimension]; + Float[] vector2 = new Float[dimension]; + int halfDimension = dimension / 2; + + for (int i = 0; i < dimension; i++) { + if (i < halfDimension) { + vector[i] = -10.76f; + vector1[i] = -65506.84f; + vector2[i] = -65526.4567f; + } else { + vector[i] = 65504.2f; + vector1[i] = 12.56f; + vector2[i] = 65526.4567f; + } + } ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "1", fieldName, vector)); assertTrue( @@ -847,8 +863,6 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { ) ); - Float[] vector1 = { -65506.84f, 12.56f }; - ResponseException ex1 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "2", fieldName, vector1)); assertTrue( ex1.getMessage() @@ -864,8 +878,6 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { ) ); - Float[] vector2 = { -65526.4567f, 65526.4567f }; - ResponseException ex2 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "3", fieldName, vector2)); assertTrue( ex2.getMessage() @@ -893,7 +905,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then List efConstructionValues = ImmutableList.of(16, 32, 64, 128); List efSearchValues = ImmutableList.of(16, 32, 64, 128); - int dimension = 2; + int dimension = 128; // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() @@ -928,16 +940,35 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then createKnnIndex(indexName, mapping); assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - Float[] vector1 = { -65523.76f, 65504.2f }; - Float[] vector2 = { -270.85f, 65514.2f }; - Float[] vector3 = { -150.9f, 65504.0f }; - Float[] vector4 = { -20.89f, 100000000.0f }; + + Float[] vector1 = new Float[dimension]; + Float[] vector2 = new Float[dimension]; + Float[] vector3 = new Float[dimension]; + Float[] vector4 = new Float[dimension]; + float[] queryVector = new float[dimension]; + int halfDimension = dimension / 2; + + for (int i = 0; i < dimension; i++) { + if (i < halfDimension) { + vector1[i] = -65523.76f; + vector2[i] = -270.85f; + vector3[i] = -150.9f; + vector4[i] = -20.89f; + queryVector[i] = -10.5f; + } else { + vector1[i] = 65504.2f; + vector2[i] = 65514.2f; + vector3[i] = 65504.0f; + vector4[i] = 100000000.0f; + queryVector[i] = 25.48f; + } + } + addKnnDoc(indexName, "1", fieldName, vector1); addKnnDoc(indexName, "2", fieldName, vector2); addKnnDoc(indexName, "3", fieldName, vector3); addKnnDoc(indexName, "4", fieldName, vector4); - float[] queryVector = { -10.5f, 25.48f }; int k = 4; Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k); List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java index 3a33736fa3..474caa1c59 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java @@ -269,5 +269,29 @@ public void testResolveMethod_whenInvalid_thenThrow() { ) ); + + Map parameters = Map.of("m", 3); + + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters); + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext); + + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(10) + .versionCreated(Version.CURRENT) + .compressionLevel(CompressionLevel.x8) + .mode(Mode.ON_DISK) + .build(); + + ValidationException validationException = expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod(knnMethodContext, knnMethodConfigContext, false, SpaceType.INNER_PRODUCT) + + ); + + assertTrue( + validationException.getMessage().contains("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions") + ); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index fdffc91d02..5bf3337aaf 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -626,56 +626,6 @@ public void testValidation_invalid_descriptionToLong() { assertTrue(validationErrors.get(0).contains("Description exceeds limit")); } - public void testValidation_invalid_mNotDivisibleByDimension() { - - // Setup the training request - String modelId = "test-model-id"; - int dimension = 10; - String trainingIndex = "test-training-index"; - String trainingField = "test-training-field"; - String trainingFieldModeId = "training-field-model-id"; - - Map parameters = Map.of("m", 3); - - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters); - final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext); - - TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null, - VectorDataType.DEFAULT, - Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED - ); - - // Mock the model dao to return metadata for modelId to recognize it is a duplicate - ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); - when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); - - ModelDao modelDao = mock(ModelDao.class); - when(modelDao.getMetadata(modelId)).thenReturn(null); - when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata); - - // Cluster service that wont produce validation exception - ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); - - // Initialize static components with the mocks - TrainingModelRequest.initialize(modelDao, clusterService); - - // Test that validation produces m not divisible by vector dimension error message - ActionRequestValidationException exception = trainingModelRequest.validate(); - assertNotNull(exception); - List validationErrors = exception.validationErrors(); - logger.error("Validation errors " + validationErrors); - assertEquals(2, validationErrors.size()); - assertTrue(validationErrors.get(1).contains("Training request ENCODER_PARAMETER_PQ_M")); - } - public void testValidation_valid_trainingIndexBuiltFromMethod() { // This cluster service will result in no validation exceptions