From 2cc1523b14d813a48a8bf11cbf786d677d4dc277 Mon Sep 17 00:00:00 2001 From: AnnTian Shao Date: Wed, 5 Feb 2025 13:26:46 -0800 Subject: [PATCH] fixes and moved all validation to encoder Signed-off-by: AnnTian Shao --- .../opensearch/knn/index/engine/Encoder.java | 5 ++- .../engine/faiss/AbstractFaissMethod.java | 27 ++++++++++++++ .../engine/faiss/AbstractFaissPQEncoder.java | 25 +++++++++++++ .../index/engine/faiss/FaissFlatEncoder.java | 9 ----- .../index/engine/faiss/FaissHNSWMethod.java | 35 +++++-------------- .../index/engine/faiss/FaissIVFMethod.java | 34 +++++------------- .../index/engine/faiss/FaissSQEncoder.java | 8 ----- .../index/engine/faiss/QFrameBitEncoder.java | 26 -------------- .../index/engine/lucene/LuceneSQEncoder.java | 8 ----- .../engine/AbstractMethodResolverTests.java | 6 ---- 10 files changed, 72 insertions(+), 111 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/engine/Encoder.java b/src/main/java/org/opensearch/knn/index/engine/Encoder.java index ae12c4f1ab..76b9d02903 100644 --- a/src/main/java/org/opensearch/knn/index/engine/Encoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/Encoder.java @@ -42,5 +42,8 @@ default String getName() { * * @return Validation output of encoder parameters */ - TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput); + default TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput) { + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + return builder.build(); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 356292678f..107640ba90 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -129,6 +129,33 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m return (MethodComponentContext) object; } + protected String getEncoderName(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext); + if (methodComponentContext == null) { + return null; + } + + return methodComponentContext.getName(); + } + + protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER); + } + + protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) { + return knnMethodContext != null + && knnMethodContext.getMethodComponentContext().getParameters() != null + && knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER); + } + @Override protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) { // While FAISS doesn't directly support cosine similarity, we can leverage the mathematical diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java index 80eae98374..0ff8d85660 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java @@ -13,6 +13,7 @@ import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import org.opensearch.knn.index.engine.TrainingConfigValidationInput; import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; @@ -98,6 +99,7 @@ public CompressionLevel calculateCompressionLevel( public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext(); + Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); @@ -113,6 +115,29 @@ public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValida builder.valid(true); } } + + // validate number of training points should be greater than minimum clustering criteria defined in faiss + if (knnMethodContext != null && trainingVectors != null) { + long minTrainingVectorCount = 1000; + + MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER); + + if (encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { + + int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); + minTrainingVectorCount = (long) Math.pow(2, code_size); + } + + if (trainingVectors < minTrainingVectorCount) { + builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); + return builder.build(); + } else { + builder.valid(true); + } + } + return builder.build(); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java index fc84c9a711..f7d4342fc8 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java @@ -16,9 +16,6 @@ import java.util.Set; -import org.opensearch.knn.index.engine.TrainingConfigValidationInput; -import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; - /** * Flat faiss encoder. Flat encoding means that it does nothing. It needs an encoder, though, because it * is used in generating the index description. @@ -55,10 +52,4 @@ public CompressionLevel calculateCompressionLevel( ) { return CompressionLevel.x1; } - - @Override - public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - return builder.build(); - } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index c8e9fc4b51..b1b0f00631 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -34,8 +34,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; /** * Faiss HNSW method implementation @@ -136,34 +134,17 @@ protected Function { KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); - Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - // validate number of training points should be greater than minimum clustering criteria defined in faiss - if (knnMethodContext != null && trainingVectors != null) { - long minTrainingVectorCount = 1000; - - MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER); - - if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST) - && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { - - int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST)); - int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); - minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size)); - } - - if (trainingVectors < minTrainingVectorCount) { - builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); - return builder.build(); - } else { - builder.valid(true); - } + if (isEncoderSpecified(knnMethodContext) == false) { + return builder.build(); + } + Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext)); + if (encoder == null) { + return builder.build(); } - return builder.build(); + + return encoder.validateEncoderConfig(trainingConfigValidationInput); }; } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 94c04155b2..731829e7a4 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -37,7 +37,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_LIMIT; -import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; /** * Faiss ivf implementation @@ -161,34 +160,17 @@ protected Function { KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); - Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount(); - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - // validate number of training points should be greater than minimum clustering criteria defined in faiss - if (knnMethodContext != null && trainingVectors != null) { - long minTrainingVectorCount = 1000; - - MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER); - - if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST) - && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { - - int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST)); - int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); - minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size)); - } - - if (trainingVectors < minTrainingVectorCount) { - builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); - return builder.build(); - } else { - builder.valid(true); - } + if (isEncoderSpecified(knnMethodContext) == false) { + return builder.build(); } - return builder.build(); + Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext)); + if (encoder == null) { + return builder.build(); + } + + return encoder.validateEncoderConfig(trainingConfigValidationInput); }; } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java index 24f72849ee..cd7e1e5f38 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java @@ -13,8 +13,6 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; import org.opensearch.knn.index.mapper.CompressionLevel; -import org.opensearch.knn.index.engine.TrainingConfigValidationInput; -import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import java.util.Objects; import java.util.Set; @@ -63,10 +61,4 @@ public CompressionLevel calculateCompressionLevel( // TODO: Hard code for now return CompressionLevel.x2; } - - @Override - public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - return builder.build(); - } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java index 066a7fea3d..2292dc3ccf 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java @@ -17,9 +17,6 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.index.engine.TrainingConfigValidationInput; -import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; -import org.opensearch.knn.index.engine.KNNMethodContext; import java.util.HashMap; import java.util.Locale; @@ -27,7 +24,6 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_FLAT_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; /** * Quantization framework binary encoder, @@ -114,26 +110,4 @@ public CompressionLevel calculateCompressionLevel( // Validation will ensure that only 1 of the supported bit count will be selected. return CompressionLevel.x8; } - - @Override - public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { - KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext(); - KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext(); - - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - - // validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension - if (knnMethodContext != null && knnMethodConfigContext != null) { - if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M) - && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(ENCODER_PARAMETER_PQ_M) != 0) { - builder.valid(false); - return builder.build(); - } else { - builder.valid(true); - } - } - return builder.build(); - } } diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java index 4bc41421d3..6bd16ebee8 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java @@ -13,8 +13,6 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; import org.opensearch.knn.index.mapper.CompressionLevel; -import org.opensearch.knn.index.engine.TrainingConfigValidationInput; -import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import java.util.List; import java.util.Set; @@ -63,10 +61,4 @@ public CompressionLevel calculateCompressionLevel( // Hard coding to 4x for now, given thats all that is supported. return CompressionLevel.x4; } - - @Override - public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - return builder.build(); - } } diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java index dd9cdb5996..f214592461 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java @@ -47,12 +47,6 @@ public CompressionLevel calculateCompressionLevel( ) { return DEFAULT_COMPRESSION; } - - @Override - public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) { - TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); - return builder.build(); - } }; private final static Map ENCODER_MAP = Map.of(ENCODER_NAME, TEST_ENCODER);